# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


from collections import namedtuple
from functools import partial
import inspect
from itertools import product
import math
import os
from typing import Callable

import numpy as np
from numpy.testing import assert_allclose, assert_array_equal
import pytest
import scipy
from scipy.sparse import csr_matrix
import scipy.stats as osp

import jax
from jax import grad, lax, vmap
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import expit, logsumexp
from jax.scipy.stats import norm as jax_norm, truncnorm as jax_truncnorm

import numpyro.distributions as dist
from numpyro.distributions import (
    SineBivariateVonMises,
    constraints,
    kl_divergence,
    transforms,
)
from numpyro.distributions.batch_util import vmap_over
from numpyro.distributions.censored import (
    IntervalCensoredDistribution,
    LeftCensoredDistribution,
    RightCensoredDistribution,
)
from numpyro.distributions.discrete import _to_probs_bernoulli, _to_probs_multinom
from numpyro.distributions.flows import InverseAutoregressiveTransform
from numpyro.distributions.transforms import (
    LowerCholeskyAffine,
    PermuteTransform,
    PowerTransform,
    SimplexToOrderedTransform,
    SoftplusTransform,
    biject_to,
)
from numpyro.distributions.util import (
    matrix_to_tril_vec,
    multinomial,
    signed_stick_breaking_tril,
    sum_rightmost,
    vec_to_tril_matrix,
)
from numpyro.nn import AutoregressiveNN


def my_kron(A, B):
    """Return the batched Kronecker product of the trailing two axes of ``A`` and ``B``.

    Leading (batch) axes broadcast against each other; for matrices of shape
    ``(m, n)`` and ``(p, q)`` the result has shape ``(m * p, n * q)``.
    """
    # outer[..., i, k, j, l] == A[..., i, j] * B[..., k, l]
    outer = A[..., :, None, :, None] * B[..., None, :, None, :]
    *batch, rows_a, rows_b, cols_a, cols_b = outer.shape
    return outer.reshape((*batch, rows_a * rows_b, cols_a * cols_b))


def _identity(x):
    """Return ``x`` unchanged (a no-op stand-in for ``jax.jit``)."""
    result = x
    return result


def _circ_mean(angles):
    """Circular mean of ``angles`` (radians) along axis 0, in ``[-pi, pi]``."""
    mean_sin = jnp.mean(jnp.sin(angles), axis=0)
    mean_cos = jnp.mean(jnp.cos(angles), axis=0)
    return jnp.arctan2(mean_sin, mean_cos)


def sde_fn1(x, _):
    """Linear SDE coefficients: drift ``0.1 * x`` and constant ``sigma2 = 0.1``.

    The second argument (presumably time) is ignored.
    """
    lam, sigma2 = 0.1, 0.1
    drift = lam * x
    return drift, sigma2


def sde_fn2(xy, _):
    """Two-dimensional nonlinear SDE coefficients with a constant ``sigma2 = 0.1``.

    With ``x, y = xy[0], xy[1]``, ``tau = 2`` and ``a = 1.1`` the drift is
    ``(tau * (x - x**3 / 3 + y), (a - x) / tau)``, reshaped back to ``xy.shape``.
    The second argument (presumably time) is ignored.
    """
    tau, a = 2.0, 1.1
    x = xy[0]
    y = xy[1]
    drift_x = tau * (x - x**3.0 / 3.0 + y)
    drift_y = (1.0 / tau) * (a - x)
    drift = jnp.vstack([drift_x, drift_y]).reshape(xy.shape)
    sigma2 = 0.1
    return drift, sigma2


class T(namedtuple("TestCase", ["jax_dist", "sp_dist", "params"])):
    """One parametrization case: ``(jax_dist, sp_dist, params)``.

    Construct as ``T(jax_dist, *params)``. ``sp_dist`` is filled in
    automatically by looking ``jax_dist`` up with ``get_sp_dist``, and the
    remaining positional arguments are collected into the ``params`` tuple.
    """

    def __new__(cls, jax_dist, *params):
        sp_dist = get_sp_dist(jax_dist)
        # Zero-argument super() resolves to super(T, cls), so this also works
        # for subclasses of T. The previous super(cls, T) raised TypeError there.
        return super().__new__(cls, jax_dist, sp_dist, params)


def _mvn_to_scipy(loc, cov, prec, tril):
    """Frozen SciPy multivariate normal equivalent to ``dist.MultivariateNormal``.

    The test cases pass one of ``cov``/``prec``/``tril`` (others ``None``).
    numpyro resolves whichever is given, and SciPy receives the resulting
    mean and covariance matrix.
    """
    mvn = dist.MultivariateNormal(loc, cov, prec, tril)
    return osp.multivariate_normal(mean=mvn.mean, cov=mvn.covariance_matrix)


def _multivariate_t_to_scipy(df, loc, tril):
    """Frozen SciPy ``multivariate_t`` equivalent to ``dist.MultivariateStudentT``.

    :param df: degrees of freedom (SciPy expects a scalar).
    :param loc: location vector.
    :param tril: lower-triangular Cholesky factor of the scale matrix.
    :return: a frozen ``scipy.stats.multivariate_t``.

    Skips the calling test when the installed SciPy lacks ``multivariate_t``.
    """
    # Feature-detect rather than compare version strings. Lexicographically,
    # "1.10.0" < "1.6.0" is True, which wrongly skipped on SciPy >= 1.10.
    if not hasattr(osp, "multivariate_t"):
        pytest.skip(
            "Multivariate Student-T distribution is not available in scipy < 1.6"
        )
    jax_dist = dist.MultivariateStudentT(df, loc, tril)
    # SciPy's ``shape`` is the scale matrix L @ L.T, not the covariance
    # (covariance = df / (df - 2) * shape). ``loc`` is defined for every df,
    # whereas the mean is mathematically undefined for df <= 1.
    scale_tril = jax_dist.scale_tril
    shape_matrix = jnp.matmul(scale_tril, jnp.swapaxes(scale_tril, -1, -2))
    full_loc = jnp.broadcast_to(
        jax_dist.loc, jax_dist.batch_shape + jax_dist.event_shape
    )
    return osp.multivariate_t(loc=full_loc, shape=shape_matrix, df=df)


def _lowrank_mvn_to_scipy(loc, cov_fac, cov_diag):
    """Frozen SciPy multivariate normal equivalent to ``dist.LowRankMultivariateNormal``.

    numpyro assembles the dense covariance from ``cov_fac`` and ``cov_diag``,
    and SciPy receives that covariance together with the mean.
    """
    low_rank = dist.LowRankMultivariateNormal(loc, cov_fac, cov_diag)
    return osp.multivariate_normal(mean=low_rank.mean, cov=low_rank.covariance_matrix)


def _truncnorm_to_scipy(loc, scale, low, high):
    """Frozen ``scipy.stats.truncnorm`` for a Normal(loc, scale) truncated to [low, high].

    SciPy takes the bounds in standardized units, ``(bound - loc) / scale``.
    A ``None`` bound means unbounded on that side.
    """
    a = -np.inf if low is None else (low - loc) / scale
    b = np.inf if high is None else (high - loc) / scale
    return osp.truncnorm(a, b, loc=loc, scale=scale)


def _wishart_to_scipy(conc, scale, rate, tril):
    """Frozen ``scipy.stats.wishart`` equivalent to ``dist.Wishart``.

    Skips the calling test unless the resolved concentration is a scalar,
    since SciPy accepts only a single degrees-of-freedom value.
    """
    wishart = dist.Wishart(conc, scale, rate, tril)
    concentration = wishart.concentration
    if jnp.isscalar(concentration):
        # np.isscalar is False for 0-d jax arrays, so hand SciPy a Python float.
        return osp.wishart(float(concentration), wishart.scale_matrix)
    pytest.skip("scipy Wishart only supports a single scalar concentration")


def _invwishart_to_scipy(conc, scale, rate, tril):
    """Frozen ``scipy.stats.invwishart`` equivalent to ``dist.InverseWishart``.

    Skips the calling test unless the resolved concentration is a scalar,
    since SciPy accepts only a single degrees-of-freedom value.
    """
    inv_wishart = dist.InverseWishart(conc, scale, rate, tril)
    concentration = inv_wishart.concentration
    if jnp.isscalar(concentration):
        # np.isscalar is False for 0-d jax arrays, so hand SciPy a Python float.
        return osp.invwishart(float(concentration), inv_wishart.scale_matrix)
    pytest.skip("scipy invwishart only supports a single scalar concentration")


def _circulant_to_scipy(loc, covariance_row, covariance_rfft):
    """Frozen SciPy multivariate normal equivalent to ``dist.CirculantNormal``.

    numpyro builds the dense covariance matrix from the circulant
    parametrization, and SciPy receives it together with the mean.
    """
    circulant = dist.CirculantNormal(loc, covariance_row, covariance_rfft)
    mean = circulant.mean
    covariance = circulant.covariance_matrix
    return osp.multivariate_normal(mean=mean, cov=covariance)


def _TruncatedNormal(loc, scale, low, high):
    """Adapter so test cases can pass ``low``/``high`` positionally.

    All four parameters are forwarded to ``dist.TruncatedNormal`` by keyword.
    """
    return dist.TruncatedNormal(
        loc=loc,
        scale=scale,
        low=low,
        high=high,
    )


def _TruncatedCauchy(loc, scale, low, high):
    """Adapter so test cases can pass ``low``/``high`` positionally.

    All four parameters are forwarded to ``dist.TruncatedCauchy`` by keyword.
    """
    return dist.TruncatedCauchy(
        loc=loc,
        scale=scale,
        low=low,
        high=high,
    )


def _LeftCensoredHalfNormal(scale, censored):
    """Wrap ``HalfNormal(scale)`` in ``LeftCensoredDistribution``.

    ``censored`` is forwarded unchanged.
    """
    return LeftCensoredDistribution(dist.HalfNormal(scale), censored)


def _RightCensoredWeibull(scale, concentration, censored):
    """Wrap ``Weibull(scale, concentration)`` in ``RightCensoredDistribution``.

    ``censored`` is forwarded unchanged.
    """
    return RightCensoredDistribution(dist.Weibull(scale, concentration), censored)


def _LeftCensoredNormal(loc, scale, censored):
    """Wrap ``Normal(loc, scale)`` in ``LeftCensoredDistribution``.

    ``censored`` is forwarded unchanged.
    """
    return LeftCensoredDistribution(dist.Normal(loc, scale), censored)


def _RightCensoredNormal(loc, scale, censored):
    """Wrap ``Normal(loc, scale)`` in ``RightCensoredDistribution``.

    ``censored`` is forwarded unchanged.
    """
    return RightCensoredDistribution(dist.Normal(loc, scale), censored)


def _LeftCensoredPoisson(rate, censored):
    """Wrap ``Poisson(rate)`` in ``LeftCensoredDistribution``.

    ``censored`` is forwarded unchanged.
    """
    return LeftCensoredDistribution(dist.Poisson(rate), censored)


def _RightCensoredPoisson(rate, censored):
    """Wrap ``Poisson(rate)`` in ``RightCensoredDistribution``.

    ``censored`` is forwarded unchanged.
    """
    return RightCensoredDistribution(dist.Poisson(rate), censored)


def _IntervalCensoredNormal(loc, scale, left_censored, right_censored):
    """Wrap ``Normal(loc, scale)`` in ``IntervalCensoredDistribution``.

    Both censoring arguments are forwarded unchanged.
    """
    return IntervalCensoredDistribution(
        dist.Normal(loc, scale), left_censored, right_censored
    )


def _IntervalCensoredPoisson(rate, left_censored, right_censored):
    """Wrap ``Poisson(rate)`` in ``IntervalCensoredDistribution``.

    Both censoring arguments are forwarded unchanged.
    """
    return IntervalCensoredDistribution(
        dist.Poisson(rate), left_censored, right_censored
    )


def _IntervalCensoredWeibull(scale, concentration, left_censored, right_censored):
    """Wrap ``Weibull(scale, concentration)`` in ``IntervalCensoredDistribution``.

    Both censoring arguments are forwarded unchanged.
    """
    return IntervalCensoredDistribution(
        dist.Weibull(scale, concentration), left_censored, right_censored
    )


# ``_TruncatedNormal`` is a plain factory function, not a Distribution class, so
# the class-style metadata is attached to it by hand.
# NOTE(review): presumably read by the generic shape/constraint tests — confirm.
def _truncated_normal_infer_shapes(*shapes):
    # Batch shape is the broadcast of all parameter shapes; the event is scalar.
    return lax.broadcast_shapes(*shapes), ()


_TruncatedNormal.arg_constraints = {}
_TruncatedNormal.reparametrized_params = []
_TruncatedNormal.infer_shapes = _truncated_normal_infer_shapes


class SineSkewedUniform(dist.SineSkewed):
    """``SineSkewed`` over a ``Uniform`` base on ``[-pi, pi] x [-pi, pi]``.

    The base is reinterpreted as a single 2-D event. Extra keyword arguments
    are forwarded both to the base ``Uniform`` and to ``SineSkewed``.
    """

    def __init__(self, skewness, **kwargs):
        bound = np.array([math.pi, math.pi])
        lower, upper = -bound, bound
        uniform = dist.Uniform(lower, upper, **kwargs)
        super().__init__(uniform.to_event(lower.ndim), skewness, **kwargs)


@vmap_over.register
def _vmap_over_sine_skewed_uniform(self: SineSkewedUniform, skewness=None):
    """Delegate to the ``dist.SineSkewed`` handler, forwarding ``skewness``.

    ``base_dist`` is passed as ``None``, presumably marking the fixed base as
    unbatched.
    """
    sine_skewed_handler = vmap_over.dispatch(dist.SineSkewed)
    return sine_skewed_handler(self, base_dist=None, skewness=skewness)


class SineSkewedVonMises(dist.SineSkewed):
    """``SineSkewed`` over a ``VonMises(0, 1)`` base of length 1, taken as one event.

    Extra keyword arguments are forwarded both to ``VonMises`` and to
    ``SineSkewed``.
    """

    def __init__(self, skewness, **kwargs):
        von_loc = np.array([0.0])
        von_conc = np.array([1.0])
        von_mises = dist.VonMises(von_loc, von_conc, **kwargs)
        super().__init__(von_mises.to_event(von_loc.ndim), skewness, **kwargs)


@vmap_over.register
def _vmap_over_sine_skewed_von_mises(self: SineSkewedVonMises, skewness=None):
    """Delegate to the ``dist.SineSkewed`` handler, forwarding ``skewness``.

    ``base_dist`` is passed as ``None``, presumably marking the fixed base as
    unbatched.
    """
    sine_skewed_handler = vmap_over.dispatch(dist.SineSkewed)
    return sine_skewed_handler(self, base_dist=None, skewness=skewness)


class SineSkewedVonMisesBatched(dist.SineSkewed):
    """``SineSkewed`` over a 2-component ``VonMises`` base, taken as one 2-D event.

    The components have locations ``(0, -1.234)`` and concentrations
    ``(1, 10)``. Extra keyword arguments are forwarded both to ``VonMises``
    and to ``SineSkewed``.
    """

    def __init__(self, skewness, **kwargs):
        von_loc = np.array([0.0, -1.234])
        von_conc = np.array([1.0, 10.0])
        von_mises = dist.VonMises(von_loc, von_conc, **kwargs)
        super().__init__(von_mises.to_event(von_loc.ndim), skewness, **kwargs)


@vmap_over.register
def _vmap_over_sine_skewed_von_mises_batched(
    self: SineSkewedVonMisesBatched, skewness=None
):
    """Delegate to the ``dist.SineSkewed`` handler, forwarding ``skewness``.

    ``base_dist`` is passed as ``None``, presumably marking the fixed base as
    unbatched.
    """
    sine_skewed_handler = vmap_over.dispatch(dist.SineSkewed)
    return sine_skewed_handler(self, base_dist=None, skewness=skewness)


class _GaussianMixture(dist.MixtureSameFamily):
    """Mixture of univariate Normals sharing one batched ``Normal`` component.

    ``mixing_probs`` parametrizes the Categorical mixing distribution.
    ``loc``/``scale`` parametrize the components and are exposed as properties.
    """

    # NOTE(review): left empty, presumably so generic parameter tests skip
    # constraint/reparametrization checks for this wrapper — confirm.
    arg_constraints = {}
    reparametrized_params = []

    def __init__(self, mixing_probs, loc, scale):
        normals = dist.Normal(loc=loc, scale=scale)
        super().__init__(
            mixing_distribution=dist.Categorical(probs=mixing_probs),
            component_distribution=normals,
        )

    @property
    def loc(self):
        """Locations of the Normal components."""
        return self.component_distribution.loc

    @property
    def scale(self):
        """Scales of the Normal components."""
        return self.component_distribution.scale


@vmap_over.register
def _vmap_over_gaussian_mixture(self: _GaussianMixture, loc=None, scale=None):
    """Map ``loc``/``scale`` onto the component, then delegate to ``MixtureSameFamily``."""
    component_axes = vmap_over(self.component_distribution, loc=loc, scale=scale)
    mixture_handler = vmap_over.dispatch(dist.MixtureSameFamily)
    return mixture_handler(self, _component_distribution=component_axes)


class _Gaussian2DMixture(dist.MixtureSameFamily):
    """Mixture of MultivariateNormals sharing one batched component distribution.

    ``mixing_probs`` parametrizes the Categorical mixing distribution.
    ``loc``/``covariance_matrix`` parametrize the components and are exposed
    as properties.
    """

    # NOTE(review): left empty, presumably so generic parameter tests skip
    # constraint/reparametrization checks for this wrapper — confirm.
    arg_constraints = {}
    reparametrized_params = []

    def __init__(self, mixing_probs, loc, covariance_matrix):
        gaussians = dist.MultivariateNormal(
            loc=loc, covariance_matrix=covariance_matrix
        )
        super().__init__(
            mixing_distribution=dist.Categorical(probs=mixing_probs),
            component_distribution=gaussians,
        )

    @property
    def loc(self):
        """Means of the Gaussian components."""
        return self.component_distribution.loc

    @property
    def covariance_matrix(self):
        """Covariance matrices of the Gaussian components."""
        return self.component_distribution.covariance_matrix


@vmap_over.register
def _vmap_over_gaussian_2d_mixture(self: _Gaussian2DMixture, loc=None):
    """Map ``loc`` onto the component, then delegate to ``MixtureSameFamily``."""
    component_axes = vmap_over(self.component_distribution, loc=loc)
    mixture_handler = vmap_over.dispatch(dist.MixtureSameFamily)
    return mixture_handler(self, _component_distribution=component_axes)


class _GeneralMixture(dist.MixtureGeneral):
    """``MixtureGeneral`` over a list of univariate Normal components.

    ``locs`` and ``scales`` are zipped pairwise into one ``Normal`` per pair.
    The components are mixed by a Categorical over ``mixing_probs``.
    """

    # NOTE(review): left empty, presumably so generic parameter tests skip
    # constraint/reparametrization checks for this wrapper — confirm.
    arg_constraints = {}
    reparametrized_params = []

    def __init__(self, mixing_probs, locs, scales):
        component_dists = [
            dist.Normal(loc=loc_, scale=scale_) for loc_, scale_ in zip(locs, scales)
        ]
        mixing_distribution = dist.Categorical(probs=mixing_probs)
        # Initialize the parent without returning its result. __init__ must
        # not return a value, matching the other mixture wrappers here.
        super().__init__(
            mixing_distribution=mixing_distribution,
            component_distributions=component_dists,
        )

    @property
    def locs(self):
        # hotfix for vmapping tests, which cannot easily check non-array attributes
        return self.component_distributions[0].loc

    @property
    def scales(self):
        """Scale of the first component (see the ``locs`` hotfix note)."""
        return self.component_distributions[0].scale


@vmap_over.register
def _vmap_over_general_mixture(self: _GeneralMixture, locs=None, scales=None):
    """Map ``locs``/``scales`` onto every component, then delegate to ``MixtureGeneral``.

    Every component receives the same ``locs``/``scales`` axes.
    """
    per_component_axes = [
        vmap_over(component, loc=locs, scale=scales)
        for component in self.component_distributions
    ]
    general_handler = vmap_over.dispatch(dist.MixtureGeneral)
    return general_handler(self, _component_distributions=per_component_axes)


class _General2DMixture(dist.MixtureGeneral):
    """``MixtureGeneral`` over a list of MultivariateNormal components.

    ``locs`` and ``covariance_matrices`` are zipped pairwise into one
    ``MultivariateNormal`` per pair. The components are mixed by a Categorical
    over ``mixing_probs``.
    """

    # NOTE(review): left empty, presumably so generic parameter tests skip
    # constraint/reparametrization checks for this wrapper — confirm.
    arg_constraints = {}
    reparametrized_params = []

    def __init__(self, mixing_probs, locs, covariance_matrices):
        component_dists = [
            dist.MultivariateNormal(loc=loc_, covariance_matrix=covariance_matrix)
            for loc_, covariance_matrix in zip(locs, covariance_matrices)
        ]
        mixing_distribution = dist.Categorical(probs=mixing_probs)
        # Initialize the parent without returning its result. __init__ must
        # not return a value, matching the other mixture wrappers here.
        super().__init__(
            mixing_distribution=mixing_distribution,
            component_distributions=component_dists,
        )

    @property
    def locs(self):
        # hotfix for vmapping tests, which cannot easily check non-array attributes
        return self.component_distributions[0].loc

    @property
    def covariance_matrices(self):
        """Covariance of the first component (see the ``locs`` hotfix note)."""
        return self.component_distributions[0].covariance_matrix


@vmap_over.register
def _vmap_over_general_2d_mixture(self: _General2DMixture, locs=None):
    """Map ``locs`` onto every component, then delegate to ``MixtureGeneral``."""
    per_component_axes = [
        vmap_over(component, loc=locs) for component in self.component_distributions
    ]
    general_handler = vmap_over.dispatch(dist.MixtureGeneral)
    return general_handler(self, _component_distributions=per_component_axes)


class _ImproperWrapper(dist.ImproperUniform):
    """``ImproperUniform`` with a concrete ``sample`` for testing.

    Values are drawn uniformly from ``[-2, 2)`` in unconstrained space and
    mapped onto ``self.support`` with ``biject_to``.
    """

    def sample(self, key, sample_shape=()):
        to_constrained = biject_to(self.support)
        # Pull a zero prototype back through the inverse transform to learn the
        # unconstrained event shape, which may differ from ``event_shape``.
        prototype = jnp.zeros(self.event_shape)
        unconstrained_event_shape = jnp.shape(to_constrained.inv(prototype))
        draw_shape = sample_shape + self.batch_shape + unconstrained_event_shape
        raw_draws = random.uniform(key, draw_shape, minval=-2, maxval=2)
        return to_constrained(raw_draws)


class ZeroInflatedPoissonLogits(dist.discrete.ZeroInflatedLogits):
    """Zero-inflated ``Poisson(rate)`` with the gate given as logits.

    ``rate`` is also kept directly on the instance.
    """

    arg_constraints = {"rate": constraints.positive, "gate_logits": constraints.real}
    # ``rate`` is declared as extra pytree data, presumably so JAX transforms
    # treat it like the base class's array fields.
    pytree_data_fields = ("rate",)

    def __init__(self, rate, gate_logits, *, validate_args=None):
        self.rate = rate
        poisson = dist.Poisson(rate)
        super().__init__(poisson, gate_logits, validate_args=validate_args)


@vmap_over.register
def _vmap_over_zero_inflated_poisson_logits(
    self: ZeroInflatedPoissonLogits, rate=None, gate_logits=None
):
    """Delegate to the ``ZeroInflatedLogits`` handler.

    ``rate`` is mapped onto the Poisson base. ``gate`` shares the axis of
    ``gate_logits``. The wrapper's own ``rate`` attribute is then set on the
    result.
    """
    poisson_axes = vmap_over(self.base_dist, rate=rate)
    zero_inflated_handler = vmap_over.dispatch(dist.discrete.ZeroInflatedLogits)
    dist_axes = zero_inflated_handler(
        self,
        base_dist=poisson_axes,
        gate_logits=gate_logits,
        gate=gate_logits,
    )
    # ``rate`` is also stored directly on the wrapper, so mirror its axis.
    dist_axes.rate = rate
    return dist_axes


class SparsePoisson(dist.Poisson):
    """``Poisson`` that always constructs itself with ``is_sparse=True``."""

    def __init__(self, rate, *, validate_args=None):
        super().__init__(
            rate,
            is_sparse=True,
            validate_args=validate_args,
        )


class FoldedNormal(dist.FoldedDistribution):
    """Folded ``Normal(loc, scale)`` that also keeps ``loc``/``scale`` as attributes."""

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}

    def __init__(self, loc, scale, validate_args=None):
        # Store the raw parameters before the base class is initialized.
        self.loc, self.scale = loc, scale
        normal = dist.Normal(loc, scale)
        super().__init__(normal, validate_args=validate_args)


@vmap_over.register
def _vmap_over_folded_normal(self: "FoldedNormal", loc=None, scale=None):
    """Delegate to the ``FoldedDistribution`` handler.

    ``loc``/``scale`` are mapped onto the Normal base, then mirrored on the
    wrapper's own attributes.
    """
    base_axes = vmap_over(self.base_dist, loc=loc, scale=scale)
    folded_handler = vmap_over.dispatch(dist.FoldedDistribution)
    axes = folded_handler(self, base_dist=base_axes)
    axes.loc, axes.scale = loc, scale
    return axes


class _SparseCAR(dist.CAR):
    """``CAR`` that always forwards ``is_sparse=True`` to the base class.

    The ``is_sparse`` keyword is accepted, presumably to mirror ``dist.CAR``'s
    signature, but its value is ignored.
    """

    reparametrized_params = ["loc", "correlation", "conditional_precision"]

    def __init__(
        self,
        loc,
        correlation,
        conditional_precision,
        adj_matrix,
        *,
        is_sparse=True,
        validate_args=None,
    ):
        positional = (loc, correlation, conditional_precision, adj_matrix)
        # NOTE(review): ``is_sparse`` is deliberately not forwarded.
        super().__init__(*positional, is_sparse=True, validate_args=validate_args)


# Maps a numpyro distribution class (or a test-local factory) to a callable.
# Each callable takes the same positional parameters as the key and returns the
# equivalent frozen scipy.stats distribution. Where the libraries parametrize
# differently, the callable translates (e.g. rate -> scale, shifted support).
_DIST_MAP = {
    # scipy's laplace_asymmetric takes the asymmetry (kappa) positionally.
    dist.AsymmetricLaplace: lambda loc, scale, asymmetry: osp.laplace_asymmetric(
        asymmetry, loc=loc, scale=scale
    ),
    dist.BernoulliProbs: lambda probs: osp.bernoulli(p=probs),
    dist.BernoulliLogits: lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)),
    dist.Beta: lambda con1, con0: osp.beta(con1, con0),
    # Mean/concentration -> (alpha, beta) = (mu * kappa, (1 - mu) * kappa).
    dist.BetaProportion: lambda mu, kappa: osp.beta(mu * kappa, (1 - mu) * kappa),
    dist.BinomialProbs: lambda probs, total_count: osp.binom(n=total_count, p=probs),
    dist.BinomialLogits: lambda logits, total_count: osp.binom(
        n=total_count, p=_to_probs_bernoulli(logits)
    ),
    dist.Cauchy: lambda loc, scale: osp.cauchy(loc=loc, scale=scale),
    dist.Chi2: lambda df: osp.chi2(df),
    dist.CirculantNormal: _circulant_to_scipy,
    dist.Dirichlet: lambda conc: osp.dirichlet(conc),
    # scipy's randint excludes its upper bound, hence high + 1.
    dist.DiscreteUniform: lambda low, high: osp.randint(low, high + 1),
    # scipy parametrizes the exponential and gamma by scale = 1 / rate.
    dist.Exponential: lambda rate: osp.expon(scale=jnp.reciprocal(rate)),
    dist.Gamma: lambda conc, rate: osp.gamma(conc, scale=1.0 / rate),
    # scipy's geom support starts at 1; loc=-1 shifts it to start at 0.
    dist.GeometricProbs: lambda probs: osp.geom(p=probs, loc=-1),
    dist.GeometricLogits: lambda logits: osp.geom(
        p=_to_probs_bernoulli(logits), loc=-1
    ),
    dist.Gumbel: lambda loc, scale: osp.gumbel_r(loc=loc, scale=scale),
    dist.HalfCauchy: lambda scale: osp.halfcauchy(scale=scale),
    dist.HalfNormal: lambda scale: osp.halfnorm(scale=scale),
    # For the inverse gamma, scipy's ``scale`` plays the role of numpyro's rate.
    dist.InverseGamma: lambda conc, rate: osp.invgamma(conc, scale=rate),
    dist.Laplace: lambda loc, scale: osp.laplace(loc=loc, scale=scale),
    # scipy's lognorm uses s = sigma and scale = exp(mu).
    dist.LogNormal: lambda loc, scale: osp.lognorm(s=scale, scale=jnp.exp(loc)),
    dist.LogUniform: lambda a, b: osp.loguniform(a, b),
    dist.MultinomialProbs: lambda probs, total_count: osp.multinomial(
        n=total_count, p=probs
    ),
    dist.MultinomialLogits: lambda logits, total_count: osp.multinomial(
        n=total_count, p=_to_probs_multinom(logits)
    ),
    dist.MultivariateNormal: _mvn_to_scipy,
    dist.MultivariateStudentT: _multivariate_t_to_scipy,
    dist.LowRankMultivariateNormal: _lowrank_mvn_to_scipy,
    dist.Normal: lambda loc, scale: osp.norm(loc=loc, scale=scale),
    dist.Pareto: lambda scale, alpha: osp.pareto(alpha, scale=scale),
    dist.Poisson: lambda rate: osp.poisson(rate),
    dist.StudentT: lambda df, loc, scale: osp.t(df=df, loc=loc, scale=scale),
    # scipy's uniform is parametrized by (loc, width), so pass b - a.
    dist.Uniform: lambda a, b: osp.uniform(a, b - a),
    dist.Logistic: lambda loc, scale: osp.logistic(loc=loc, scale=scale),
    # Parameters are cast to float64 NumPy arrays before reaching scipy.
    dist.VonMises: lambda loc, conc: osp.vonmises(
        loc=np.array(loc, dtype=np.float64), kappa=np.array(conc, dtype=np.float64)
    ),
    dist.Weibull: lambda scale, conc: osp.weibull_min(
        c=conc,
        scale=scale,
    ),
    dist.Wishart: _wishart_to_scipy,
    dist.InverseWishart: _invwishart_to_scipy,
    _TruncatedNormal: _truncnorm_to_scipy,
    dist.Levy: lambda loc, scale: osp.levy(loc=loc, scale=scale),
}


def get_sp_dist(jax_dist):
    """Return the SciPy factory registered in ``_DIST_MAP`` for ``jax_dist``.

    For a class, its MRO is searched in order, so the closest registered
    ancestor wins. A non-class (e.g. a factory function) is looked up directly.
    Returns ``None`` when nothing matches.
    """
    candidates = jax_dist.mro() if isinstance(jax_dist, type) else [jax_dist]
    return next(
        (_DIST_MAP[candidate] for candidate in candidates if candidate in _DIST_MAP),
        None,
    )


CONTINUOUS = [
    T(dist.AsymmetricLaplace, 1.0, 0.5, 1.0),
    T(dist.AsymmetricLaplace, np.array([1.0, 2.0]), 2.0, 2.0),
    T(dist.AsymmetricLaplace, np.array([[1.0], [2.0]]), 2.0, np.array([3.0, 5.0])),
    T(dist.AsymmetricLaplaceQuantile, 0.0, 1.0, 0.5),
    T(dist.AsymmetricLaplaceQuantile, np.array([1.0, 2.0]), 2.0, 0.7),
    T(
        dist.AsymmetricLaplaceQuantile,
        np.array([[1.0], [2.0]]),
        2.0,
        np.array([0.2, 0.8]),
    ),
    T(dist.Beta, 0.2, 1.1),
    T(dist.Beta, 1.0, np.array([2.0, 2.0])),
    T(dist.Beta, 1.0, np.array([[1.0, 1.0], [2.0, 2.0]])),
    T(dist.BetaProportion, 0.2, 10.0),
    T(dist.BetaProportion, 0.51, np.array([2.0, 1.0])),
    T(dist.BetaProportion, 0.5, np.array([[4.0, 4.0], [2.0, 2.0]])),
    T(dist.Chi2, 2.0),
    T(dist.Chi2, np.array([0.3, 1.3])),
    T(dist.Cauchy, 0.0, 1.0),
    T(dist.Cauchy, 0.0, np.array([1.0, 2.0])),
    T(dist.Cauchy, np.array([0.0, 1.0]), np.array([[1.0], [2.0]])),
    T(_RightCensoredWeibull, 1.0, 1.0, 0),
    T(_RightCensoredWeibull, 1.0, 1.0, 1),
    T(_LeftCensoredHalfNormal, 1.0, 0),
    T(_LeftCensoredHalfNormal, 1.0, 1),
    T(_LeftCensoredNormal, 0.0, 1.0, 0),
    T(_LeftCensoredNormal, 0.0, 1.0, 1),
    T(_RightCensoredNormal, 0.0, 1.0, 0),
    T(_RightCensoredNormal, 0.0, 1.0, 1),
    T(_IntervalCensoredNormal, 0.0, 1.0, 0, 0),
    T(_IntervalCensoredNormal, 0.0, 1.0, 0, 1),
    T(_IntervalCensoredNormal, 0.0, 1.0, 1, 0),
    T(_IntervalCensoredNormal, 0.0, 1.0, 1, 1),
    T(_IntervalCensoredWeibull, 1.0, 1.0, 0, 1),
    T(_IntervalCensoredWeibull, 1.0, 1.0, 1, 0),
    T(_IntervalCensoredWeibull, 1.0, 1.0, 1, 1),
    T(_IntervalCensoredWeibull, 1.0, 1.0, 1, 1),
    T(dist.CirculantNormal, np.zeros((3, 4)), np.array([0.9, 0.2, 0.1, 0.2]), None),
    T(
        dist.CirculantNormal,
        np.zeros((7, 5)),
        np.array([0.9, 0.2, 0.1, 0.1, 0.2]),
        None,
    ),
    T(
        dist.CirculantNormal,
        np.zeros(4),
        np.broadcast_to(np.array([0.9, 0.2, 0.1, 0.2]), (3, 4)),
        None,
    ),
    T(
        dist.CirculantNormal,
        np.zeros(5),
        np.broadcast_to(np.array([0.9, 0.2, 0.1, 0.1, 0.2]), (7, 5)),
        None,
    ),
    T(dist.Dirichlet, np.array([1.7])),
    T(dist.Dirichlet, np.array([0.2, 1.1])),
    T(dist.Dirichlet, np.array([[0.2, 1.1], [2.0, 2.0]])),
    T(
        dist.EulerMaruyama,
        np.array([0.0, 0.1, 0.2]),
        sde_fn1,
        dist.Normal(0.1, 1.0),
    ),
    T(
        dist.EulerMaruyama,
        np.array([0.0, 0.1, 0.2]),
        sde_fn2,
        dist.Normal(np.array([0.0, 1.0]), 1e-3).to_event(1),
    ),
    T(
        dist.EulerMaruyama,
        np.array([[0.0, 0.1, 0.2], [10.0, 10.1, 10.2]]),
        sde_fn2,
        dist.Normal(np.array([0.0, 1.0]), 1e-3).to_event(1),
    ),
    T(
        dist.EulerMaruyama,
        np.array([[0.0, 0.1, 0.2], [10.0, 10.1, 10.2]]),
        sde_fn2,
        dist.Normal(np.array([[0.0, 1.0], [2.0, 3.0]]), 1e-2).to_event(1),
    ),
    T(dist.Exponential, 2.0),
    T(dist.Exponential, np.array([4.0, 2.0])),
    T(dist.Gamma, np.array([1.7]), np.array([[2.0], [3.0]])),
    T(dist.Gamma, np.array([0.5, 1.3]), np.array([[1.0], [3.0]])),
    T(dist.GaussianRandomWalk, 0.1, 10),
    T(dist.GaussianRandomWalk, np.array([0.1, 0.3, 0.25]), 10),
    T(
        dist.GaussianStateSpace,
        10,
        np.array([[0.8, 0.2], [-0.1, 1.1]]),
        np.array([[0.8, 0.2], [0.2, 0.7]]),
    ),
    T(
        dist.GaussianStateSpace,
        5,
        np.array([[0.8, 0.2], [-0.1, 1.1]]),
        np.array([0.1, 0.3, 0.25])[:, None, None] * np.array([[0.8, 0.2], [0.2, 0.7]]),
    ),
    pytest.param(
        *T(
            dist.GaussianCopulaBeta,
            np.array([7.0, 2.0]),
            np.array([4.0, 10.0]),
            np.array([[1.0, 0.75], [0.75, 1.0]]),
        ),
        marks=pytest.mark.xfail(reason="Beta copula does not work with jax 0.7.0"),
    ),
    pytest.param(
        *T(dist.GaussianCopulaBeta, 2.0, 1.5, np.eye(3)),
        marks=pytest.mark.xfail(reason="Beta copula does not work with jax 0.7.0"),
    ),
    pytest.param(
        *T(dist.GaussianCopulaBeta, 2.0, 1.5, np.full((5, 3, 3), np.eye(3))),
        marks=pytest.mark.xfail(reason="Beta copula does not work with jax 0.7.0"),
    ),
    T(dist.Gompertz, np.array([1.7]), np.array([[2.0], [3.0]])),
    T(dist.Gompertz, np.array([0.5, 1.3]), np.array([[1.0], [3.0]])),
    T(dist.Gumbel, 0.0, 1.0),
    T(dist.Gumbel, 0.5, 2.0),
    T(dist.Gumbel, np.array([0.0, 0.5]), np.array([1.0, 2.0])),
    T(FoldedNormal, 2.0, 4.0),
    T(FoldedNormal, np.array([2.0, 50.0]), np.array([4.0, 100.0])),
    T(dist.HalfCauchy, 1.0),
    T(dist.HalfCauchy, np.array([1.0, 2.0])),
    T(dist.HalfNormal, 1.0),
    T(dist.HalfNormal, np.array([1.0, 2.0])),
    T(_ImproperWrapper, constraints.positive, (), (3,)),
    T(dist.InverseGamma, np.array([3.1]), np.array([[2.0], [3.0]])),
    T(dist.InverseGamma, np.array([1.7]), np.array([[2.0], [3.0]])),
    T(dist.InverseGamma, np.array([0.5, 1.3]), np.array([[1.0], [3.0]])),
    T(dist.Kumaraswamy, 10.0, np.array([2.0, 3.0])),
    T(dist.Kumaraswamy, np.array([1.7]), np.array([[2.0], [3.0]])),
    T(dist.Kumaraswamy, 0.6, 0.5),
    T(dist.Laplace, 0.0, 1.0),
    T(dist.Laplace, 0.5, np.array([1.0, 2.5])),
    T(dist.Laplace, np.array([1.0, -0.5]), np.array([2.3, 3.0])),
    T(dist.LKJ, 2, 0.5, "onion"),
    T(dist.LKJ, 5, np.array([0.5, 1.0, 2.0]), "cvine"),
    T(dist.LKJCholesky, 2, 0.5, "onion"),
    T(dist.LKJCholesky, 2, 0.5, "cvine"),
    T(dist.LKJCholesky, 5, np.array([0.5, 1.0, 2.0]), "onion"),
    pytest.param(
        *T(dist.LKJCholesky, 5, np.array([0.5, 1.0, 2.0]), "cvine"),
        marks=pytest.mark.skipif("CI" in os.environ, reason="reduce time for CI"),
    ),
    pytest.param(
        *T(dist.LKJCholesky, 3, np.array([[3.0, 0.6], [0.2, 5.0]]), "onion"),
        marks=pytest.mark.skipif("CI" in os.environ, reason="reduce time for CI"),
    ),
    T(dist.LKJCholesky, 3, np.array([[3.0, 0.6], [0.2, 5.0]]), "cvine"),
    T(dist.Logistic, 0.0, 1.0),
    T(dist.Logistic, 1.0, np.array([1.0, 2.0])),
    T(dist.Logistic, np.array([0.0, 1.0]), np.array([[1.0], [2.0]])),
    T(dist.LogNormal, 1.0, 0.2),
    T(dist.LogNormal, -1.0, np.array([0.5, 1.3])),
    T(dist.LogNormal, np.array([0.5, -0.7]), np.array([[0.1, 0.4], [0.5, 0.1]])),
    T(dist.LogUniform, 1.0, 2.0),
    T(dist.LogUniform, 1.0, np.array([2.0, 3.0])),
    T(dist.LogUniform, np.array([1.0, 2.0]), np.array([[3.0], [4.0]])),
    T(
        dist.MatrixNormal,
        1.0 * np.arange(6).reshape(3, 2),
        np.array([[1.0, 0, 0], [0.3, 0.36, 0], [0.4, 0.49, 4]]),
        np.array([[1.0, 0], [0.4, 1]]),
    ),
    T(
        dist.MatrixNormal,
        1.0 * np.arange(12).reshape((2, 3, 2)),
        np.array([[1.0, 0, 0], [0.3, 0.36, 0], [0.4, 0.49, 4]]) * np.ones((2, 3, 3)),
        np.array([[1.0, 0], [0.4, 0.5]]) * np.ones((2, 2, 2)),
    ),
    T(
        dist.MatrixNormal,
        1.0 * np.arange(36).reshape((2, 3, 3, 2)),
        np.identity(3),
        np.identity(2),
    ),
    T(dist.MultivariateNormal, 0.0, np.array([[1.0, 0.5], [0.5, 1.0]]), None, None),
    T(
        dist.MultivariateNormal,
        np.array([1.0, 3.0]),
        None,
        np.array([[1.0, 0.5], [0.5, 1.0]]),
        None,
    ),
    T(
        dist.MultivariateNormal,
        np.array([1.0, 3.0]),
        None,
        np.array([[[1.0, 0.5], [0.5, 1.0]]]),
        None,
    ),
    T(
        dist.MultivariateNormal,
        np.array([2.0]),
        None,
        None,
        np.array([[1.0, 0.0], [0.5, 1.0]]),
    ),
    T(
        dist.MultivariateNormal,
        np.arange(6, dtype=np.float32).reshape((3, 2)),
        None,
        None,
        np.array([[1.0, 0.0], [0.0, 1.0]]),
    ),
    T(
        dist.MultivariateNormal,
        0.0,
        None,
        np.broadcast_to(np.identity(3), (2, 3, 3)),
        None,
    ),
    T(
        dist.CAR,
        1.2,
        np.array([-0.2, 0.3]),
        0.1,
        np.array(
            [
                [0.0, 1.0, 1.0, 0.0],
                [1.0, 0.0, 0.0, 1.0],
                [1.0, 0.0, 0.0, 1.0],
                [0.0, 1.0, 1.0, 0.0],
            ]
        ),
    ),
    T(
        dist.CAR,
        np.array([0.0, 1.0, 3.0, 4.0]),
        0.1,
        np.array([0.3, 0.7]),
        np.array(
            [
                [0.0, 1.0, 1.0, 0.0],
                [1.0, 0.0, 0.0, 1.0],
                [1.0, 0.0, 0.0, 1.0],
                [0.0, 1.0, 1.0, 0.0],
            ]
        ),
    ),
    T(
        _SparseCAR,
        np.array([[0.0, 1.0, 3.0, 4.0], [2.0, -1.0, -3.0, 2.0]]),
        0.0,
        0.1,
        np.array(
            [
                [0.0, 1.0, 1.0, 0.0],
                [1.0, 0.0, 0.0, 1.0],
                [1.0, 0.0, 0.0, 1.0],
                [0.0, 1.0, 1.0, 0.0],
            ]
        ),
    ),
    T(
        dist.MultivariateStudentT,
        15.0,
        0.0,
        np.array([[1.0, 0.0], [0.5, 1.0]]),
    ),
    T(
        dist.MultivariateStudentT,
        15.0,
        np.array([1.0, 3.0]),
        np.array([[1.0, 0.0], [0.5, 1.0]]),
    ),
    T(
        dist.MultivariateStudentT,
        15.0,
        np.array([1.0, 3.0]),
        np.array([[[1.0, 0.0], [0.5, 1.0]]]),
    ),
    T(
        dist.MultivariateStudentT,
        15.0,
        np.array([3.0]),
        np.array([[1.0, 0.0], [0.5, 1.0]]),
    ),
    T(
        dist.MultivariateStudentT,
        15.0,
        np.arange(6, dtype=np.float32).reshape((3, 2)),
        np.array([[1.0, 0.0], [0.5, 1.0]]),
    ),
    T(
        dist.MultivariateStudentT,
        15.0,
        np.ones(3),
        np.broadcast_to(np.identity(3), (2, 3, 3)),
    ),
    T(
        dist.MultivariateStudentT,
        np.array(7.0),
        np.array([1.0, 3.0]),
        np.array([[1.0, 0.0], [0.5, 1.0]]),
    ),
    T(
        dist.MultivariateStudentT,
        np.arange(20, 22, dtype=jnp.float32),
        np.ones(3),
        np.broadcast_to(np.identity(3), (2, 3, 3)),
    ),
    T(
        dist.MultivariateStudentT,
        np.arange(20, 26, dtype=jnp.float32).reshape((3, 2)),
        np.ones(2),
        np.array([[1.0, 0.0], [0.5, 1.0]]),
    ),
    T(
        dist.LowRankMultivariateNormal,
        np.zeros(2),
        np.array([[1.0], [0.0]]),
        np.array([1.0, 1.0]),
    ),
    T(
        dist.LowRankMultivariateNormal,
        np.arange(6, dtype=np.float32).reshape((2, 3)),
        np.arange(6, dtype=np.float32).reshape((3, 2)),
        np.array([1.0, 2.0, 3.0]),
    ),
    T(dist.Normal, 0.0, 1.0),
    T(dist.Normal, 1.0, np.array([1.0, 2.0])),
    T(dist.Normal, np.array([0.0, 1.0]), np.array([[1.0], [2.0]])),
    T(dist.Pareto, 1.0, 2.0),
    T(dist.Pareto, np.array([1.0, 0.5]), np.array([0.3, 2.0])),
    T(dist.Pareto, np.array([[1.0], [3.0]]), np.array([1.0, 0.5])),
    T(dist.RelaxedBernoulliLogits, 2.0, -10.0),
    T(dist.RelaxedBernoulliLogits, np.array([1.0, 3.0]), np.array([3.0, 8.0])),
    T(dist.SoftLaplace, 1.0, 1.0),
    T(dist.SoftLaplace, np.array([-1.0, 50.0]), np.array([4.0, 100.0])),
    T(dist.StudentT, 1.0, 1.0, 0.5),
    T(dist.StudentT, 2.0, np.array([1.0, 2.0]), 2.0),
    T(dist.StudentT, np.array([3.0, 5.0]), np.array([[1.0], [2.0]]), 2.0),
    T(_TruncatedCauchy, 0.0, 1.0, -1.0, None),
    T(_TruncatedCauchy, 0.0, np.array([1.0, 2.0]), 1.0, None),
    T(
        _TruncatedCauchy,
        np.array([0.0, 1.0]),
        np.array([[1.0], [2.0]]),
        np.array([-2.0, 2.0]),
        None,
    ),
    T(_TruncatedCauchy, 0.0, 1.0, None, 1.0),
    T(_TruncatedCauchy, 0.0, 1.0, -1.0, 1.0),
    T(_TruncatedNormal, 0.0, 1.0, -1.0, None),
    T(_TruncatedNormal, -1.0, np.array([1.0, 2.0]), 1.0, None),
    T(
        _TruncatedNormal,
        np.array([0.0, 1.0]),
        np.array([[1.0], [2.0]]),
        np.array([-2.0, 2.0]),
        None,
    ),
    T(_TruncatedNormal, -1.0, 2.0, 1.0, 5.0),
    T(_TruncatedNormal, np.array([-1.0, 4.0]), 2.0, None, 5.0),
    T(_TruncatedNormal, -1.0, np.array([2.0, 3.0]), 1.0, None),
    T(_TruncatedNormal, -1.0, 2.0, np.array([-6.0, 4.0]), np.array([-4.0, 6.0])),
    T(
        _TruncatedNormal,
        np.array([0.0, 1.0]),
        np.array([[1.0], [2.0]]),
        None,
        np.array([-2.0, 2.0]),
    ),
    T(dist.TwoSidedTruncatedDistribution, dist.Laplace(0.0, 1.0), -2.0, 3.0),
    T(dist.Uniform, 0.0, 2.0),
    T(dist.Uniform, 1.0, np.array([2.0, 3.0])),
    T(dist.Uniform, np.array([0.0, 0.0]), np.array([[2.0], [3.0]])),
    T(dist.Weibull, 0.2, 1.1),
    T(dist.Weibull, 2.8, np.array([2.0, 2.0])),
    T(dist.Weibull, 1.8, np.array([[1.0, 1.0], [2.0, 2.0]])),
    T(dist.Wishart, 3, 2 * np.eye(2) + 0.1, None, None),
    T(
        dist.Wishart,
        3.0,
        None,
        np.array([[1.0, 0.5], [0.5, 1.0]]),
        None,
    ),
    T(
        dist.Wishart,
        np.array([4.0, 5.0]),
        None,
        np.array([[[1.0, 0.5], [0.5, 1.0]]]),
        None,
    ),
    T(
        dist.Wishart,
        np.array([3.0]),
        None,
        None,
        np.array([[1.0, 0.0], [0.5, 1.0]]),
    ),
    T(
        dist.Wishart,
        np.arange(3, 9, dtype=np.float32).reshape((3, 2)),
        None,
        None,
        np.array([[1.0, 0.0], [0.0, 1.0]]),
    ),
    T(
        dist.Wishart,
        9.0,
        None,
        np.broadcast_to(np.identity(3), (2, 3, 3)),
        None,
    ),
    T(dist.WishartCholesky, 3, 2 * np.eye(2) + 0.1, None, None),
    T(
        dist.WishartCholesky,
        3.0,
        None,
        np.array([[1.0, 0.5], [0.5, 1.0]]),
        None,
    ),
    T(
        dist.WishartCholesky,
        np.array([4.0, 5.0]),
        None,
        np.array([[[1.0, 0.5], [0.5, 1.0]]]),
        None,
    ),
    T(
        dist.WishartCholesky,
        np.array([3.0]),
        None,
        None,
        np.array([[1.0, 0.0], [0.5, 1.0]]),
    ),
    T(
        dist.WishartCholesky,
        np.arange(3, 9, dtype=np.float32).reshape((3, 2)),
        None,
        None,
        np.array([[1.0, 0.0], [0.0, 1.0]]),
    ),
    T(
        dist.WishartCholesky,
        9.0,
        None,
        np.broadcast_to(np.identity(3), (2, 3, 3)),
        None,
    ),
    # InverseWishart tests: concentration must be > p - 1 (for p=2, need conc > 1)
    # Using concentration > p + 1 to ensure mean exists
    T(dist.InverseWishart, 5, 2 * np.eye(2) + 0.1, None, None),
    T(
        dist.InverseWishart,
        5.0,
        None,
        np.array([[1.0, 0.5], [0.5, 1.0]]),
        None,
    ),
    T(
        dist.InverseWishart,
        np.array([6.0, 7.0]),
        None,
        np.array([[[1.0, 0.5], [0.5, 1.0]]]),
        None,
    ),
    T(
        dist.InverseWishart,
        np.array([5.0]),
        None,
        None,
        np.array([[1.0, 0.0], [0.5, 1.0]]),
    ),
    T(
        dist.InverseWishart,
        np.arange(5, 11, dtype=np.float32).reshape((3, 2)),
        None,
        None,
        np.array([[1.0, 0.0], [0.0, 1.0]]),
    ),
    T(
        dist.InverseWishart,
        9.0,
        None,
        np.broadcast_to(np.identity(3), (2, 3, 3)),
        None,
    ),
    T(dist.InverseWishartCholesky, 5, 2 * np.eye(2) + 0.1, None, None),
    T(
        dist.InverseWishartCholesky,
        5.0,
        None,
        np.array([[1.0, 0.5], [0.5, 1.0]]),
        None,
    ),
    T(
        dist.InverseWishartCholesky,
        np.array([6.0, 7.0]),
        None,
        np.array([[[1.0, 0.5], [0.5, 1.0]]]),
        None,
    ),
    T(
        dist.InverseWishartCholesky,
        np.array([5.0]),
        None,
        None,
        np.array([[1.0, 0.0], [0.5, 1.0]]),
    ),
    T(
        dist.InverseWishartCholesky,
        np.arange(5, 11, dtype=np.float32).reshape((3, 2)),
        None,
        None,
        np.array([[1.0, 0.0], [0.0, 1.0]]),
    ),
    T(
        dist.InverseWishartCholesky,
        9.0,
        None,
        np.broadcast_to(np.identity(3), (2, 3, 3)),
        None,
    ),
    T(dist.ZeroSumNormal, 1.0, (5,)),
    T(dist.ZeroSumNormal, np.array([2.0]), (5,)),
    T(dist.ZeroSumNormal, 1.0, (4, 5)),
    T(
        _GaussianMixture,
        np.ones(3) / 3.0,
        np.array([0.0, 7.7, 2.1]),
        np.array([4.2, 7.7, 2.1]),
    ),
    T(
        _Gaussian2DMixture,
        np.array([0.2, 0.5, 0.3]),
        np.array([[-1.2, 1.5], [2.0, 2.0], [-1, 4.0]]),  # Mean
        np.array(
            [
                [
                    [0.1, -0.2],
                    [-0.2, 1.0],
                ],
                [
                    [0.75, 0.0],
                    [0.0, 0.75],
                ],
                [
                    [1.0, 0.5],
                    [0.5, 0.27],
                ],
            ]
        ),  # Covariance
    ),
    T(
        _GeneralMixture,
        np.array([0.2, 0.3, 0.5]),
        np.array([0.0, 7.7, 2.1]),
        np.array([4.2, 1.7, 2.1]),
    ),
    T(
        _General2DMixture,
        np.array([0.2, 0.5, 0.3]),
        np.array([[-1.2, 1.5], [2.0, 2.0], [-1, 4.0]]),  # Mean
        np.array(
            [
                [
                    [0.1, -0.2],
                    [-0.2, 1.0],
                ],
                [
                    [0.75, 0.0],
                    [0.0, 0.75],
                ],
                [
                    [1.0, 0.5],
                    [0.5, 0.27],
                ],
            ]
        ),  # Covariance
    ),
    T(dist.LowerTruncatedPowerLaw, -np.pi, np.array([2.0, 5.0, 10.0, 50.0])),
    T(dist.DoublyTruncatedPowerLaw, -1.0, 1.0, 2.0),
    T(dist.DoublyTruncatedPowerLaw, np.pi, 5.0, 50.0),
    T(dist.DoublyTruncatedPowerLaw, -1.0, 5.0, 50.0),
    T(dist.DoublyTruncatedPowerLaw, np.pi, 1.0, 2.0),
    T(dist.Levy, 0.0, 1.0),
    T(dist.Levy, 0.0, np.array([1.0, 2.0, 10.0])),
    T(dist.Levy, np.array([1.0, 2.0, 10.0]), np.pi),
    T(dist.Dagum, 1.0, 1.0, 1.0),
    T(dist.Dagum, 3.0, 4.0, 5.0),
    T(dist.Dagum, 2.0, np.array([1.0, 2.0, 10.0]), 4.0),
    T(dist.Dagum, 2.0, 3.0, np.array([0.5, 2.0, 1.0])),
    T(dist.Dagum, np.array([5.0, 2.0, 10.0]), 3.0, 5.0),
]

# Directional (circular / spherical) distributions. Each ``T(...)`` entry (``T`` is
# defined earlier in this file) is unpacked into the ``jax_dist, sp_dist, params``
# arguments of the parametrized tests below. Entries wrapped in ``pytest.param``
# are skipped when the ``CI`` environment variable is set, to reduce CI run time.
DIRECTIONAL = [
    T(dist.VonMises, 2.0, 10.0),
    T(dist.VonMises, 2.0, np.array([150.0, 10.0])),
    T(dist.VonMises, np.array([1 / 3 * np.pi, -1.0]), np.array([20.0, 30.0])),
    pytest.param(
        *T(
            dist.SineBivariateVonMises,
            0.0,
            0.0,
            5.0,
            6.0,
            2.0,
        ),
        marks=pytest.mark.skipif("CI" in os.environ, reason="reduce time for CI"),
    ),
    T(
        dist.SineBivariateVonMises,
        3.003,
        -1.343,
        5.0,
        6.0,
        2.0,
    ),
    pytest.param(
        *T(
            dist.SineBivariateVonMises,
            -1.232,
            -1.3430,
            3.4,
            2.0,
            1.0,
        ),
        marks=pytest.mark.skipif("CI" in os.environ, reason="reduce time for CI"),
    ),
    pytest.param(
        *T(
            dist.SineBivariateVonMises,
            np.array([math.pi - 0.2, 1.0]),
            np.array([0.0, 1.0]),
            np.array([5.0, 5.0]),
            np.array([7.0, 0.5]),
            None,
            np.array([0.5, 0.1]),
        ),
        marks=pytest.mark.skipif("CI" in os.environ, reason="reduce time for CI"),
    ),
    T(dist.ProjectedNormal, np.array([0.0, 0.0])),
    T(dist.ProjectedNormal, np.array([[2.0, 3.0]])),
    T(dist.ProjectedNormal, np.array([0.0, 0.0, 0.0])),
    T(dist.ProjectedNormal, np.array([[-1.0, 2.0, 3.0]])),
    T(SineSkewedUniform, np.array([-math.pi / 4, 0.1])),
    T(SineSkewedVonMises, np.array([0.342355])),
    T(SineSkewedVonMisesBatched, np.array([[0.342355, -0.0001], [0.91, 0.09]])),
]

# Discrete distributions. Each ``T(...)`` entry is unpacked into the
# ``jax_dist, sp_dist, params`` arguments of the parametrized tests below.
DISCRETE = [
    T(dist.BetaBinomial, 2.0, 5.0, 10),
    T(
        dist.BetaBinomial,
        np.array([2.0, 4.0]),
        np.array([5.0, 3.0]),
        np.array([10, 12]),
    ),
    T(dist.BetaNegativeBinomial, 2.0, 5.0, 10),
    T(
        dist.BetaNegativeBinomial,
        np.array([2.0, 4.0]),
        np.array([5.0, 3.0]),
        np.array([5, 10]),
    ),
    T(dist.BernoulliProbs, 0.2),
    T(dist.BernoulliProbs, np.array([0.2, 0.7])),
    T(dist.BernoulliLogits, np.array([-1.0, 3.0])),
    T(dist.BinomialProbs, np.array([0.2, 0.7]), np.array([10, 2])),
    T(dist.BinomialProbs, np.array([0.2, 0.7]), np.array([5, 8])),
    T(dist.BinomialLogits, np.array([-1.0, 3.0]), np.array([5, 8])),
    T(dist.CategoricalProbs, np.array([1.0])),
    T(dist.CategoricalProbs, np.array([0.1, 0.5, 0.4])),
    T(dist.CategoricalProbs, np.array([[0.1, 0.5, 0.4], [0.4, 0.4, 0.2]])),
    T(dist.CategoricalLogits, np.array([-5.0])),
    T(dist.CategoricalLogits, np.array([1.0, 2.0, -2.0])),
    T(dist.CategoricalLogits, np.array([[-1, 2.0, 3.0], [3.0, -4.0, -2.0]])),
    T(dist.Delta, 1),
    T(dist.Delta, np.array([0.0, 2.0])),
    T(dist.Delta, np.array([0.0, 2.0]), np.array([-2.0, -4.0])),
    T(dist.DirichletMultinomial, np.array([1.0, 2.0, 3.9]), 10),
    T(dist.DirichletMultinomial, np.array([0.2, 0.7, 1.1]), np.array([5, 5])),
    T(dist.GammaPoisson, 2.0, 2.0),
    T(dist.GammaPoisson, np.array([6.0, 2]), np.array([2.0, 8.0])),
    T(dist.GeometricProbs, 0.2),
    T(dist.GeometricProbs, np.array([0.2, 0.7])),
    T(dist.GeometricLogits, np.array([-1.0, 3.0])),
    T(_LeftCensoredPoisson, 1.0, 0),
    T(_LeftCensoredPoisson, 1.0, 1),
    T(_RightCensoredPoisson, 1.0, 0),
    T(_RightCensoredPoisson, 1.0, 1),
    T(_IntervalCensoredPoisson, 3, 0, 0),
    T(_IntervalCensoredPoisson, 3, 0, 1),
    T(_IntervalCensoredPoisson, 3, 1, 0),
    # NOTE(review): this entry duplicates the one above; possibly (3, 1, 1) was
    # intended to complete the 0/1 combinations -- confirm against
    # ``_IntervalCensoredPoisson``'s signature before changing.
    T(_IntervalCensoredPoisson, 3, 1, 0),
    T(dist.MultinomialProbs, np.array([0.2, 0.7, 0.1]), 10),
    T(dist.MultinomialProbs, np.array([0.2, 0.7, 0.1]), np.array([5, 8])),
    T(dist.MultinomialLogits, np.array([-1.0, 3.0]), np.array([[5], [8]])),
    T(dist.NegativeBinomialProbs, 10, 0.2),
    T(dist.NegativeBinomialProbs, 10, np.array([0.2, 0.6])),
    T(dist.NegativeBinomialProbs, np.array([4.2, 10.7, 2.1]), 0.2),
    T(
        dist.NegativeBinomialProbs,
        np.array([4.2, 10.7, 2.1]),
        np.array([0.2, 0.6, 0.5]),
    ),
    T(dist.NegativeBinomialLogits, 10, -2.1),
    T(dist.NegativeBinomialLogits, 10, np.array([-5.2, 2.1])),
    T(dist.NegativeBinomialLogits, np.array([4.2, 10.7, 2.1]), -5.2),
    T(
        dist.NegativeBinomialLogits,
        np.array([4.2, 7.7, 2.1]),
        np.array([4.2, 0.7, 2.1]),
    ),
    T(dist.NegativeBinomial2, 0.3, 10),
    T(dist.NegativeBinomial2, np.array([10.2, 7, 31]), 10),
    T(dist.NegativeBinomial2, np.array([10.2, 7, 31]), np.array([10.2, 20.7, 2.1])),
    T(dist.OrderedLogistic, -2, np.array([-10.0, 4.0, 9.0])),
    T(dist.OrderedLogistic, np.array([-4, 3, 4, 5]), np.array([-1.5])),
    T(dist.DiscreteUniform, -2, np.array([-1.0, 4.0, 9.0])),
    T(dist.DiscreteUniform, np.array([-4, 3, 4, 5]), np.array([6])),
    T(dist.Poisson, 2.0),
    T(dist.Poisson, np.array([2.0, 3.0, 5.0])),
    T(SparsePoisson, 2.0),
    T(SparsePoisson, np.array([2.0, 3.0, 5.0])),
    T(SparsePoisson, 2),
    T(dist.ZeroInflatedPoisson, 0.6, 2.0),
    T(dist.ZeroInflatedPoisson, np.array([0.2, 0.7, 0.3]), np.array([2.0, 3.0, 5.0])),
    T(ZeroInflatedPoissonLogits, 2.0, 3.0),
    T(
        ZeroInflatedPoissonLogits,
        np.array([0.2, 4.0, 0.3]),
        np.array([2.0, -3.0, 5.0]),
    ),
]

# Distributions built from a base ``Normal`` reshaped via ``to_event`` and/or
# ``expand``, exercising non-trivial batch/event shapes. Of the tests in this part
# of the file, only ``test_entropy_samples`` is parametrized over these.
BASE = [
    T(lambda *args: dist.Normal(*args).to_event(2), np.arange(24).reshape(3, 4, 2)),
    T(lambda *args: dist.Normal(*args).expand((3, 4, 7)), np.arange(7)),
    T(
        lambda *args: dist.Normal(*args).to_event(2).expand((7, 3)),
        np.arange(24).reshape(3, 4, 2),
    ),
]


def _is_batched_multivariate(jax_dist):
    return len(jax_dist.event_shape) > 0 and len(jax_dist.batch_shape) > 0


def gen_values_within_bounds(constraint, size, key=None):
    """Generate an array of shape ``size`` whose values satisfy ``constraint``.

    :param constraint: the numpyro constraint the generated values must satisfy.
    :param tuple size: shape of the generated array; for vector/matrix
        constraints the trailing dimension(s) hold the event.
    :param key: optional JAX PRNG key; defaults to ``random.PRNGKey(11)``.
        Note that the ``simplex`` branch draws from scipy's (global NumPy) RNG
        and does not use ``key``.
    :raises NotImplementedError: if ``constraint`` is not handled here.
    """
    key = random.PRNGKey(11) if key is None else key
    # small offset keeping samples strictly away from open bounds
    eps = 1e-6

    if constraint is constraints.boolean:
        return random.bernoulli(key, shape=size)
    elif isinstance(constraint, constraints.greater_than):
        return jnp.exp(random.normal(key, size)) + constraint.lower_bound + eps
    elif isinstance(constraint, constraints.less_than):
        return constraint.upper_bound - jnp.exp(random.normal(key, size)) - eps
    elif isinstance(constraint, constraints.integer_interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        # randint's upper bound is exclusive, hence the +1
        return random.randint(key, size, lower_bound, upper_bound + 1)
    elif isinstance(constraint, constraints.integer_greater_than):
        return constraint.lower_bound + random.poisson(key, np.array(5), shape=size)
    elif isinstance(constraint, constraints.interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=lower_bound, maxval=upper_bound)
    elif constraint in (constraints.real, constraints.real_vector):
        return random.normal(key, size)
    elif constraint is constraints.simplex:
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1])
    elif isinstance(constraint, constraints.multinomial):
        n = size[-1]
        return multinomial(
            key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1]
        )
    elif constraint is constraints.corr_cholesky:
        return signed_stick_breaking_tril(
            random.uniform(
                key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1
            )
        )
    elif constraint is constraints.corr_matrix:
        cholesky = signed_stick_breaking_tril(
            random.uniform(
                key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1
            )
        )
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif constraint is constraints.lower_cholesky:
        return jnp.tril(random.uniform(key, size))
    elif constraint is constraints.positive_definite:
        x = random.normal(key, size)
        return jnp.matmul(x, jnp.swapaxes(x, -2, -1))
    elif constraint is constraints.ordered_vector:
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x - random.normal(key, size[:-1] + (1,))
    elif isinstance(constraint, constraints.independent):
        return gen_values_within_bounds(constraint.base_constraint, size, key)
    elif constraint is constraints.sphere:
        x = random.normal(key, size)
        # keepdims so each vector is divided by its own norm for any batch shape
        return x / jnp.linalg.norm(x, axis=-1, keepdims=True)
    elif constraint is constraints.l1_ball:
        key1, key2 = random.split(key)
        sign = random.bernoulli(key1)
        bounds = [0, (-1) ** sign * 0.5]
        return random.uniform(key, size, float, *sorted(bounds))
    elif isinstance(constraint, constraints.zero_sum):
        x = random.normal(key, size)
        zero_sum_axes = tuple(i for i in range(-constraint.event_dim, 0))
        for axis in zero_sum_axes:
            # keepdims so the mean broadcasts back along the reduced axis, which
            # is required for batched samples and non-square event shapes
            x -= x.mean(axis, keepdims=True)
        return x
    elif constraint is constraints.positive_definite_circulant_vector:
        return jnp.fft.irfft(random.gamma(key, 10, size) / 10, n=size[-1])
    else:
        raise NotImplementedError("{} not implemented.".format(constraint))


def gen_values_outside_bounds(constraint, size, key=None):
    """Generate an array of shape ``size`` whose values violate ``constraint``.

    :param constraint: the numpyro constraint the generated values must violate.
    :param tuple size: shape of the generated array; for vector/matrix
        constraints the trailing dimension(s) hold the event.
    :param key: optional JAX PRNG key; defaults to ``random.PRNGKey(11)``.
        Note that the ``simplex`` branch draws from scipy's (global NumPy) RNG
        and does not use ``key``.
    :raises NotImplementedError: if ``constraint`` is not handled here.
    """
    key = random.PRNGKey(11) if key is None else key
    if constraint is constraints.boolean:
        return random.bernoulli(key, shape=size) - 2
    elif isinstance(constraint, constraints.greater_than):
        return constraint.lower_bound - jnp.exp(random.normal(key, size))
    elif isinstance(constraint, constraints.less_than):
        return constraint.upper_bound + jnp.exp(random.normal(key, size))
    elif isinstance(constraint, constraints.integer_interval):
        lower_bound = jnp.broadcast_to(constraint.lower_bound, size)
        return random.randint(key, size, lower_bound - 1, lower_bound)
    elif isinstance(constraint, constraints.integer_greater_than):
        # The extra -1 keeps values strictly below ``lower_bound`` even when the
        # Poisson draw is 0 (``lower_bound`` itself satisfies the constraint).
        return (
            constraint.lower_bound - 1 - random.poisson(key, np.array(5), shape=size)
        )
    elif isinstance(constraint, constraints.interval):
        upper_bound = jnp.broadcast_to(constraint.upper_bound, size)
        return random.uniform(key, size, minval=upper_bound, maxval=upper_bound + 1.0)
    elif constraint in [constraints.real, constraints.real_vector]:
        return lax.full(size, np.nan)
    elif constraint is constraints.simplex:
        # shifting every entry makes the rows no longer sum to one
        return osp.dirichlet.rvs(alpha=jnp.ones((size[-1],)), size=size[:-1]) + 1e-2
    elif isinstance(constraint, constraints.multinomial):
        n = size[-1]
        return (
            multinomial(
                key, p=jnp.ones((n,)) / n, n=constraint.upper_bound, shape=size[:-1]
            )
            + 1
        )
    elif constraint is constraints.corr_cholesky:
        return (
            signed_stick_breaking_tril(
                random.uniform(
                    key,
                    size[:-2] + (size[-1] * (size[-1] - 1) // 2,),
                    minval=-1,
                    maxval=1,
                )
            )
            + 1e-2
        )
    elif constraint is constraints.corr_matrix:
        cholesky = 1e-2 + signed_stick_breaking_tril(
            random.uniform(
                key, size[:-2] + (size[-1] * (size[-1] - 1) // 2,), minval=-1, maxval=1
            )
        )
        return jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    elif constraint is constraints.lower_cholesky:
        return random.uniform(key, size)
    elif constraint is constraints.positive_definite:
        return random.normal(key, size)
    elif constraint is constraints.ordered_vector:
        x = jnp.cumsum(random.exponential(key, size), -1)
        return x[..., ::-1]
    elif isinstance(constraint, constraints.independent):
        return gen_values_outside_bounds(constraint.base_constraint, size, key)
    elif constraint is constraints.sphere:
        x = random.normal(key, size)
        x = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
        return 2 * x
    elif constraint is constraints.l1_ball:
        key1, key2 = random.split(key)
        sign = random.bernoulli(key1)
        bounds = [(-1) ** sign * 1.1, (-1) ** sign * 2]
        return random.uniform(key, size, float, *sorted(bounds))
    elif isinstance(constraint, constraints.zero_sum):
        x = random.normal(key, size)
        return x
    elif constraint is constraints.positive_definite_circulant_vector:
        return random.normal(key, size)
    else:
        raise NotImplementedError("{} not implemented.".format(constraint))


@pytest.mark.parametrize(
    "jax_dist_cls, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
@pytest.mark.parametrize("prepend_shape", [(), (2,), (2, 3)])
def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape):
    """Sample shapes must equal ``sample_shape + batch_shape + event_shape``.

    Where a scipy counterpart exists (and the distribution is not a batched
    multivariate one), scipy's samples must have the same shape. For
    (Student-T / Normal) multivariate distributions the precision matrix must
    be the inverse of the covariance matrix.
    """
    jax_dist = jax_dist_cls(*params)
    samples = jax_dist.sample(key=random.PRNGKey(0), sample_shape=prepend_shape)
    if jax_dist_cls is not dist.Delta:
        assert isinstance(samples, jnp.ndarray)

    full_shape = prepend_shape + jax_dist.batch_shape + jax_dist.event_shape
    # Interval-censored distributions take an interval (lo, hi) as input but
    # return univariate samples, so the trailing interval dim is dropped.
    if isinstance(jax_dist, dist.IntervalCensoredDistribution):
        full_shape = full_shape[:-1]
    assert jnp.shape(samples) == full_shape

    is_mvt = isinstance(jax_dist, dist.MultivariateStudentT)
    if sp_dist and not _is_batched_multivariate(jax_dist):
        scipy_dist = sp_dist(*params)
        sample_size = prepend_shape + jax_dist.batch_shape
        if is_mvt:
            # An empty size is replaced by a single draw for scipy's multivariate t.
            if sample_size == ():
                sample_size = 1
            try:
                sp_samples = scipy_dist.rvs(size=sample_size)
            except ValueError:
                pytest.skip(
                    "scipy multivariate t doesn't support size with > 1 element"
                )
        else:
            # scipy's Wishart cannot take an empty tuple as the sample size, so
            # ``1`` is used instead, which yields a single unbatched sample.
            if isinstance(jax_dist, dist.Wishart):
                sample_size = sample_size or 1
            sp_samples = scipy_dist.rvs(size=sample_size)
        assert jnp.shape(sp_samples) == full_shape

    if isinstance(jax_dist, (dist.MultivariateNormal, dist.MultivariateStudentT)):
        cov = jax_dist.covariance_matrix
        assert cov.ndim == len(jax_dist.batch_shape) + 2
        assert_allclose(jax_dist.precision_matrix, jnp.linalg.inv(cov), rtol=1e-6)


@pytest.mark.parametrize(
    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_infer_shapes(jax_dist, sp_dist, params):
    """``infer_shapes`` applied to the parameter shapes must match the instance."""

    def _shape_of(param):
        # ``None`` params pass through; objects without ``shape`` count as scalars,
        # and a callable ``shape`` (e.g. a distribution's ``shape()``) is called.
        if param is None:
            return None
        shape = getattr(param, "shape", ())
        return shape() if callable(shape) else shape

    shapes = [_shape_of(param) for param in params]
    instance = jax_dist(*params)
    try:
        expected_batch_shape, expected_event_shape = type(instance).infer_shapes(
            *shapes
        )
    except NotImplementedError:
        pytest.skip(f"{type(instance).__name__}.infer_shapes() is not implemented")
    assert instance.batch_shape == expected_batch_shape
    assert instance.event_shape == expected_event_shape


@pytest.mark.parametrize(
    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_has_rsample(jax_dist, sp_dist, params):
    """``has_rsample`` is consistent across wrappers and matches ``rsample`` support.

    Masked, independent (``expand_by`` + ``to_event``) and transformed wrappers
    must report the same ``has_rsample`` as the wrapped distribution. When it is
    True, ``rsample`` must run; otherwise it must raise ``NotImplementedError``.
    """
    jax_dist = jax_dist(*params)
    masked_dist = jax_dist.mask(False)
    indept_dist = jax_dist.expand_by([2]).to_event(1)
    transf_dist = dist.TransformedDistribution(jax_dist, biject_to(constraints.real))
    assert masked_dist.has_rsample == jax_dist.has_rsample
    assert indept_dist.has_rsample == jax_dist.has_rsample
    assert transf_dist.has_rsample == jax_dist.has_rsample

    if jax_dist.has_rsample:
        # Only Delta or non-discrete distributions may report has_rsample.
        assert isinstance(jax_dist, dist.Delta) or not jax_dist.is_discrete
        if isinstance(jax_dist, dist.TransformedDistribution):
            assert jax_dist.base_dist.has_rsample
        else:
            # every constrained argument must be listed as reparametrized
            assert set(jax_dist.arg_constraints) == set(jax_dist.reparametrized_params)
        jax_dist.rsample(random.PRNGKey(0))
        # the wrappers' rsample paths are exercised only for Normal
        if isinstance(jax_dist, dist.Normal):
            masked_dist.rsample(random.PRNGKey(0))
            indept_dist.rsample(random.PRNGKey(0))
            transf_dist.rsample(random.PRNGKey(0))
    else:
        with pytest.raises(NotImplementedError):
            jax_dist.rsample(random.PRNGKey(0))
        # the wrappers' error paths are exercised only for BernoulliProbs
        if isinstance(jax_dist, dist.BernoulliProbs):
            with pytest.raises(NotImplementedError):
                masked_dist.rsample(random.PRNGKey(0))
            with pytest.raises(NotImplementedError):
                indept_dist.rsample(random.PRNGKey(0))
            with pytest.raises(NotImplementedError):
                transf_dist.rsample(random.PRNGKey(0))


@pytest.mark.parametrize(
    "jax_dist_cls, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_args_attributes(jax_dist_cls, sp_dist, params):
    """Every argument constraint must expose ``event_dim`` and ``is_discrete``.

    ``event_dim`` is not accessed for ``Delta``. The test only reads the
    attributes and passes as long as no exception is raised.
    """
    check_event_dim = jax_dist_cls != dist.Delta
    arg_constraints = jax_dist_cls(*params).arg_constraints
    for constraint in arg_constraints.values():
        if check_event_dim:
            constraint.event_dim
        constraint.is_discrete


@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)])
def test_unit(batch_shape):
    """``Unit`` samples have an empty event and score exactly ``log_factor``."""
    log_factor = random.normal(random.PRNGKey(0), batch_shape)
    unit = dist.Unit(log_factor=log_factor)
    value = unit.sample(random.PRNGKey(1))
    assert value.shape == (*batch_shape, 0)
    assert (unit.log_prob(value) == log_factor).all()


@pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS)
def test_sample_gradient(jax_dist, sp_dist, params):
    """Sample gradients w.r.t. reparametrized params match central differences.

    The summed sample is differentiated with ``jax.grad`` and compared, per
    parameter, against a central finite-difference estimate. Parameters listed
    in ``gamma_derived_params`` are excluded from this check.
    """
    # we have pathwise gradient for gamma sampler
    gamma_derived_params = {
        "Gamma": ["concentration"],
        "Beta": ["concentration1", "concentration0"],
        "BetaProportion": ["mean", "concentration"],
        "Chi2": ["df"],
        "Dirichlet": ["concentration"],
        "InverseGamma": ["concentration"],
        "LKJ": ["concentration"],
        "LKJCholesky": ["concentration"],
        "StudentT": ["df"],
    }.get(jax_dist.__name__, [])

    if (
        jax_dist in [dist.DoublyTruncatedPowerLaw, dist.Dagum]
        and jnp.result_type(float) == jnp.float32
    ):
        pytest.skip(f"{jax_dist.__name__} is tested with x64 only.")

    # Positional argument names of the constructor (``self`` dropped for classes),
    # used to turn the positional ``params`` into keyword arguments.
    dist_args = [
        p
        for p in (
            inspect.getfullargspec(jax_dist.__init__)[0][1:]
            if inspect.isclass(jax_dist)
            # account for the case where jax_dist is a function
            else inspect.getfullargspec(jax_dist)[0]
        )
    ]
    params_dict = dict(zip(dist_args[: len(params)], params))

    jax_class = type(jax_dist(**params_dict))
    reparametrized_params = [
        p for p in jax_class.reparametrized_params if p not in gamma_derived_params
    ]
    if not reparametrized_params:
        pytest.skip("{} not reparametrized.".format(jax_class.__name__))

    nonrepara_params_dict = {
        k: v for k, v in params_dict.items() if k not in reparametrized_params
    }
    repara_params = tuple(
        v for k, v in params_dict.items() if k in reparametrized_params
    )

    rng_key = random.PRNGKey(0)

    # NOTE(review): ``args`` follows ``params_dict`` (constructor) order but is
    # zipped with ``reparametrized_params`` (class attribute) order; this assumes
    # both orders agree and that every reparametrized param was supplied.
    def fn(args):
        args_dict = dict(zip(reparametrized_params, args))
        return jnp.sum(
            jax_dist(**args_dict, **nonrepara_params_dict).sample(key=rng_key)
        )

    actual_grad = jax.grad(fn)(repara_params)
    assert len(actual_grad) == len(repara_params)

    eps = 1e-3
    for i in range(len(repara_params)):
        if repara_params[i] is None:
            continue
        args_lhs = [p if j != i else p - eps for j, p in enumerate(repara_params)]
        args_rhs = [p if j != i else p + eps for j, p in enumerate(repara_params)]
        fn_lhs = fn(args_lhs)
        fn_rhs = fn(args_rhs)
        # finite diff approximation; shifting every element of parameter i by the
        # same eps yields the sum of its partial derivatives, hence the jnp.sum
        # on the autodiff gradient below
        expected_grad = (fn_rhs - fn_lhs) / (2.0 * eps)
        assert jnp.shape(actual_grad[i]) == jnp.shape(repara_params[i])
        assert_allclose(jnp.sum(actual_grad[i]), expected_grad, rtol=0.02, atol=0.03)


@pytest.mark.parametrize(
    "jax_dist, params",
    [
        (dist.Gamma, (1.0,)),
        (dist.Gamma, (0.1,)),
        (dist.Gamma, (10.0,)),
        (dist.Chi2, (1.0,)),
        (dist.Chi2, (0.1,)),
        (dist.Chi2, (10.0,)),
        (dist.Beta, (1.0, 1.0)),
        (dist.StudentT, (5.0, 2.0, 4.0)),
    ],
)
def test_pathwise_gradient(jax_dist, params):
    """Gradients through samples must match gradients of the analytic moments.

    Both helpers compute E[z + z**2] = mean + variance + mean**2: one as a Monte
    Carlo average over ``num_samples`` draws, the other in closed form.
    """
    rng_key = random.PRNGKey(0)
    num_samples = 1000000

    def monte_carlo_moment(params):
        z = jax_dist(*params).sample(key=rng_key, sample_shape=(num_samples,))
        return (z + z**2).mean(0)

    def analytic_moment(params):
        d = jax_dist(*params)
        return d.mean + d.variance + d.mean**2

    assert_allclose(
        grad(monte_carlo_moment)(params), grad(analytic_moment)(params), rtol=0.005
    )


@pytest.mark.parametrize(
    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_jit_log_likelihood(jax_dist, sp_dist, params):
    """``log_prob`` as a function of the params agrees with and without ``jit``."""
    non_jittable = {
        "EulerMaruyama",
        "GaussianRandomWalk",
        "GaussianStateSpace",
        "_ImproperWrapper",
        "LKJ",
        "LKJCholesky",
        "_SparseCAR",
        "ZeroSumNormal",
    }
    if jax_dist.__name__ in non_jittable:
        pytest.xfail(reason="non-jittable params")

    samples = jax_dist(*params).sample(key=random.PRNGKey(0), sample_shape=(2, 3))

    def log_likelihood(*dist_params):
        return jax_dist(*dist_params).log_prob(samples)

    eager = log_likelihood(*params)
    jitted = jax.jit(log_likelihood)(*params)
    assert_allclose(jitted, eager, atol=2e-5, rtol=2e-5)


@pytest.mark.parametrize(
    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
@pytest.mark.parametrize("prepend_shape", [(), (2,), (2, 3)])
@pytest.mark.parametrize("jit", [False, True])
def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
    """``log_prob`` has the expected shape and matches scipy's logpdf/logpmf.

    Truncated distributions without a scipy counterpart are compared against the
    base scipy distribution's logpdf, renormalized by the mass in ``[low, high]``.
    Batched multivariate distributions, and multivariate ones with a non-empty
    ``prepend_shape``, are skipped.
    """
    jit_fn = _identity if not jit else jax.jit
    jax_dist = jax_dist(*params)

    rng_key = random.PRNGKey(0)
    samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape)
    if isinstance(jax_dist, dist.IntervalCensoredDistribution):
        # IntervalCensoredDistribution takes interval (lo-hi) input but returns univariate samples
        samples = jnp.stack([samples, samples + 0.1], axis=-1)
    assert jax_dist.log_prob(samples).shape == prepend_shape + jax_dist.batch_shape
    truncated_dists = (
        dist.LeftTruncatedDistribution,
        dist.RightTruncatedDistribution,
        dist.TwoSidedTruncatedDistribution,
    )
    if sp_dist is None:
        if isinstance(jax_dist, truncated_dists):
            if isinstance(params[0], dist.Distribution):
                # new api
                loc, scale, low, high = (
                    params[0].loc,
                    params[0].scale,
                    params[1],
                    params[2],
                )
            else:
                # old api
                loc, scale, low, high = params
            if low is None:
                low = -np.inf
            if high is None:
                high = np.inf
            sp_dist = get_sp_dist(type(jax_dist.base_dist))(loc, scale)
            # truncated log density = base logpdf - log(mass inside [low, high])
            expected = sp_dist.logpdf(samples) - jnp.log(
                sp_dist.cdf(high) - sp_dist.cdf(low)
            )
            assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)
            return
        pytest.skip("no corresponding scipy distn.")
    if _is_batched_multivariate(jax_dist):
        pytest.skip("batching not allowed in multivariate distns.")
    if jax_dist.event_shape and prepend_shape:
        # >>> d = sp.dirichlet([1.1, 1.1])
        # >>> samples = d.rvs(size=(2,))
        # >>> d.logpdf(samples)
        # ValueError: The input vector 'x' must lie within the normal simplex ...
        pytest.skip("batched samples cannot be scored by multivariate distributions.")
    sp_dist = sp_dist(*params)
    try:
        expected = sp_dist.logpdf(samples)
    except AttributeError:
        # discrete scipy distributions expose logpmf instead of logpdf
        expected = sp_dist.logpmf(samples)
    except ValueError as e:
        # precision issue: jnp.sum(x / jnp.sum(x)) = 0.99999994 != 1
        if "The input vector 'x' must lie within the normal simplex." in str(e):
            samples = jax.device_get(samples).astype("float64")
            samples = samples / samples.sum(axis=-1, keepdims=True)
            expected = sp_dist.logpdf(samples)
        else:
            raise e
    assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)


@pytest.mark.parametrize(
    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_entropy_scipy(jax_dist, sp_dist, params):
    """The analytic entropy must match scipy's ``entropy()`` when scipy has the distn."""
    d = jax_dist(*params)

    try:
        actual = d.entropy()
    except NotImplementedError:
        pytest.skip(reason=f"distribution {d} does not implement `entropy`")

    if _is_batched_multivariate(d):
        pytest.skip("batching not allowed in multivariate distns.")
    if sp_dist is None:
        pytest.skip(reason="no corresponding scipy distribution")

    expected = sp_dist(*params).entropy()
    assert_allclose(actual, expected, atol=1e-5)


@pytest.mark.parametrize(
    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL + BASE
)
def test_entropy_samples(jax_dist, sp_dist, params):
    """The analytic entropy must agree with a Monte Carlo estimate of E[-log p(x)].

    ``num_samples`` draws are scored, and the z-score of the analytic entropy
    relative to the sample mean of ``-log p`` is computed per batch element.
    """
    jax_dist = jax_dist(*params)

    try:
        actual = jax_dist.entropy()
    except NotImplementedError:
        pytest.skip(reason=f"distribution {jax_dist} does not implement `entropy`")

    num_samples = 1000
    samples = jax_dist.sample(jax.random.key(8), (num_samples,))
    # shape: (num_samples,) + batch_shape
    neg_log_probs = -jax_dist.log_prob(samples)
    mean = neg_log_probs.mean(axis=0)
    neg_log_probs_std = neg_log_probs.std(axis=0)
    # avoid dividing by zero when -log p is constant across samples
    safe_neg_log_probs_std = jnp.where(
        jnp.equal(neg_log_probs_std, 0.0),
        jnp.finfo(jnp.result_type(float)).tiny,
        neg_log_probs_std,
    )
    # Standard error of the mean over the *sample* axis (axis 0), not the last
    # batch axis. ``std`` uses ddof=0, so dividing by sqrt(n - 1) equals the
    # unbiased standard deviation divided by sqrt(n).
    stderr = safe_neg_log_probs_std / jnp.sqrt(neg_log_probs.shape[0] - 1)
    z = (actual - mean) / stderr

    # Check the z-score is small or that all values are close. This happens, for
    # example, for uniform distributions with constant log prob and hence zero stderr.
    assert (jnp.abs(z) < 5).all() or jnp.allclose(actual, neg_log_probs, atol=1e-5)


def test_entropy_categorical():
    """Categorical entropy equals that of scipy's single-trial multinomial.

    scipy has no categorical distribution, but a multinomial with one trial has
    the same entropy, so it is used as the reference for both parametrizations.
    """
    logits = jax.random.normal(jax.random.key(9), (7,))
    probs = _to_probs_multinom(logits)
    expected = osp.multinomial(1, probs).entropy()
    for categorical in (dist.CategoricalLogits(logits), dist.CategoricalProbs(probs)):
        assert_allclose(categorical.entropy(), expected, rtol=1e-6, atol=1e-6)


def test_mixture_log_prob():
    """A mixture of two identical standard normals scores like a single one."""
    component = dist.Normal(0, 1)
    mixture = dist.MixtureSameFamily(
        dist.Categorical(logits=np.zeros(2)), component.expand([2])
    )
    assert_allclose(mixture.log_prob(0.0), component.log_prob(0.0))


@pytest.mark.parametrize(
    "jax_dist, sp_dist, params",
    # TODO: add more complete pattern for Discrete.cdf
    CONTINUOUS + [T(dist.Poisson, 2.0), T(dist.Poisson, np.array([2.0, 3.0, 5.0]))],
)
@pytest.mark.filterwarnings("ignore:overflow encountered:RuntimeWarning")
def test_cdf_and_icdf(jax_dist, sp_dist, params):
    """``cdf`` and ``icdf`` are mutually consistent and agree with scipy.

    For scalar continuous distributions, d cdf/dx must equal the density and
    d icdf/dq must equal 1 / density(icdf(q)). ``cdf(icdf(q))`` and
    ``icdf(cdf(x))`` must round-trip. Multivariate distributions are skipped.
    """
    d = jax_dist(*params)
    if d.event_dim > 0:
        pytest.skip("skip testing cdf/icdf methods of multivariate distributions")
    samples = d.sample(key=random.PRNGKey(0), sample_shape=(100,))
    quantiles = random.uniform(random.PRNGKey(1), (100,) + d.shape())
    try:
        # looser relative tolerance for the distributions listed here
        rtol = (
            2e-3
            if jax_dist
            in (
                _TruncatedCauchy,
                _TruncatedNormal,
                dist.Gamma,
                dist.LogNormal,
                dist.StudentT,
            )
            else 1e-5
        )
        if d.shape() == () and not d.is_discrete:
            # derivative of the cdf is the density
            assert_allclose(
                jax.vmap(jax.grad(d.cdf))(samples),
                jnp.exp(d.log_prob(samples)),
                atol=1e-5,
                rtol=rtol,
            )
            # derivative of the inverse cdf is 1 / density at icdf(q)
            assert_allclose(
                jax.vmap(jax.grad(d.icdf))(quantiles),
                jnp.exp(-d.log_prob(d.icdf(quantiles))),
                atol=1e-5,
                rtol=rtol,
            )
        assert_allclose(d.cdf(d.icdf(quantiles)), quantiles, atol=1e-5, rtol=1e-5)
        assert_allclose(d.icdf(d.cdf(samples)), samples, atol=1e-5, rtol=rtol)
    except NotImplementedError:
        pytest.skip("cdf/icdf not implemented")

    # test against scipy
    if not sp_dist:
        pytest.skip("no corresponding scipy distn.")
    sp_dist = sp_dist(*params)
    try:
        actual_cdf = d.cdf(samples)
        expected_cdf = sp_dist.cdf(samples)
        assert_allclose(actual_cdf, expected_cdf, atol=1e-5, rtol=1e-5)
        actual_icdf = d.icdf(quantiles)
        expected_icdf = sp_dist.ppf(quantiles)
        assert_allclose(actual_icdf, expected_icdf, atol=1e-4, rtol=1e-4)
    except NotImplementedError:
        pytest.skip("cdf/icdf not implemented")


@pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS + DISCRETE)
def test_independent_shape(jax_dist, sp_dist, params):
    """Reinterpreting ``k`` batch dims moves them from batch_shape into event_shape.

    For each ``k`` below the number of batch dims, the combined shape is split so
    that the last ``k`` batch dims join the event dims, and ``log_prob`` of a
    sample must have the remaining batch shape.
    """
    base = jax_dist(*params)
    batch_shape, event_shape = base.batch_shape, base.event_shape
    full_shape = batch_shape + event_shape
    for ndims in range(len(batch_shape)):
        indep = dist.Independent(base, reinterpreted_batch_ndims=ndims)
        value = indep.sample(random.PRNGKey(0))
        split = len(batch_shape) - ndims
        expected_batch, expected_event = full_shape[:split], full_shape[split:]
        assert indep.batch_shape == expected_batch
        assert indep.event_shape == expected_event
        assert jnp.shape(indep.log_prob(value)) == expected_batch


def _tril_cholesky_to_tril_corr(x):
    """Map the strictly-lower entries of a correlation Cholesky factor to the
    strictly-lower entries of the corresponding correlation matrix.

    ``x`` holds the strictly lower-triangular part (``diagonal=-1``) of a Cholesky
    factor. The diagonal is rebuilt so that every row has unit norm, the factor is
    multiplied by its transpose, and the strictly-lower part of the result is
    returned. Leading batch dimensions are transposed correctly, provided
    ``vec_to_tril_matrix``/``matrix_to_tril_vec`` broadcast over them (TODO confirm).
    """
    w = vec_to_tril_matrix(x, diagonal=-1)
    # Unit-norm rows: diag_i = sqrt(1 - sum_{j<i} w_ij^2).
    diag = jnp.sqrt(1 - jnp.sum(w**2, axis=-1))
    cholesky = w + jnp.expand_dims(diag, axis=-1) * jnp.identity(w.shape[-1])
    # Transpose only the trailing matrix dims; `.T` would also reverse batch dims.
    corr = jnp.matmul(cholesky, jnp.swapaxes(cholesky, -2, -1))
    return matrix_to_tril_vec(corr, diagonal=-1)


@pytest.mark.parametrize("dimension", [2, 3, 5])
def test_log_prob_LKJCholesky_uniform(dimension):
    """With concentration=1, LKJ is uniform over correlation matrices.

    Converting each LKJCholesky log density into a correlation-matrix log density
    (by subtracting the log |det| of the cholesky -> corr Jacobian) must therefore
    give the same value for every sample.
    """
    d = dist.LKJCholesky(dimension=dimension, concentration=1)
    num_samples = 5

    def _corr_log_density(seed):
        sample = d.sample(random.PRNGKey(seed))
        log_prob = d.log_prob(sample)
        sample_tril = matrix_to_tril_vec(sample, diagonal=-1)
        jacobian = jax.jacobian(_tril_cholesky_to_tril_corr)(sample_tril)
        log_abs_det = np.linalg.slogdet(jacobian)[1]
        return log_prob - log_abs_det

    corr_log_prob = np.array(
        [_corr_log_density(seed) for seed in range(num_samples)]
    )
    # every entry must equal the first one
    assert_allclose(
        corr_log_prob,
        jnp.broadcast_to(corr_log_prob[0], corr_log_prob.shape),
        rtol=1e-6,
    )

    if dimension == 2:
        # A 2x2 correlation matrix has a single free entry, uniform on (-1, 1),
        # so its density is 0.5. The cholesky -> corr map leaves that off-diagonal
        # entry unchanged, so the log-Jacobian is 0 and the target is log(0.5).
        assert_allclose(corr_log_prob[0], jnp.log(0.5), rtol=1e-6)


@pytest.mark.parametrize("dimension", [2, 3, 5])
@pytest.mark.parametrize("concentration", [0.6, 2.2])
def test_log_prob_LKJCholesky(dimension, concentration):
    """Check ``LKJCholesky.log_prob`` against an explicit change of variables.

    LKJCholesky (C-vine sampling) can be viewed as a TransformedDistribution:
    partial correlations come from a Beta base distribution, are mapped from
    (0, 1) to (-1, 1) by x -> 2x - 1, and are then turned into a Cholesky factor
    by the signed stick-breaking process. The expected log density is the Beta
    log density minus the log-Jacobians of those two steps.
    """
    d = dist.LKJCholesky(dimension, concentration, sample_method="cvine")

    # Draw from the internal Beta base distribution (`d._beta`) and score it.
    beta_sample = d._beta.sample(random.PRNGKey(0))
    beta_log_prob = jnp.sum(d._beta.log_prob(beta_sample))
    # Affine map 2x - 1 onto (-1, 1); each coordinate contributes log(2).
    partial_correlation = 2 * beta_sample - 1
    affine_logdet = beta_sample.shape[-1] * jnp.log(2)
    sample = signed_stick_breaking_tril(partial_correlation)

    # compute signed stick breaking logdet
    # `biject_to(corr_cholesky)` maps from `unconstrained = atanh(partial_correlation)`.
    # Presumably it applies tanh first, so adding the log-derivative of atanh removes
    # that factor and leaves the stick-breaking Jacobian alone (TODO confirm).
    inv_tanh = lambda t: jnp.log((1 + t) / (1 - t)) / 2  # noqa: E731
    inv_tanh_logdet = jnp.sum(jnp.log(vmap(grad(inv_tanh))(partial_correlation)))
    unconstrained = inv_tanh(partial_correlation)
    corr_cholesky_logdet = biject_to(constraints.corr_cholesky).log_abs_det_jacobian(
        unconstrained, sample
    )
    signed_stick_breaking_logdet = corr_cholesky_logdet + inv_tanh_logdet

    actual_log_prob = d.log_prob(sample)
    expected_log_prob = beta_log_prob - affine_logdet - signed_stick_breaking_logdet
    assert_allclose(actual_log_prob, expected_log_prob, rtol=2e-5)

    # log_prob must give the same value under jit.
    assert_allclose(jax.jit(d.log_prob)(sample), d.log_prob(sample), atol=2e-6)


def test_zero_inflated_logits_probs_agree():
    """ZeroInflatedDistribution gives identical log_prob for gate logits and gate probs."""
    # NB: np.random.normal(1) draws one value with loc=1 (not size=1).
    concentration = np.exp(np.random.normal(1))
    rate = np.exp(np.random.normal(1))
    base = dist.GammaPoisson(concentration, rate)
    gate_logits = np.random.normal(0)
    by_logits = dist.ZeroInflatedDistribution(base, gate_logits=gate_logits)
    by_probs = dist.ZeroInflatedDistribution(base, gate=expit(gate_logits))
    values = np.random.randint(0, 20, (1000, 100))
    assert_allclose(by_probs.log_prob(values), by_logits.log_prob(values))


@pytest.mark.parametrize("rate", [0.1, 0.5, 0.9, 1.0, 1.1, 2.0, 10.0])
def test_ZIP_log_prob(rate):
    """ZeroInflatedPoisson reduces to Poisson at gate=0 and to Delta(0) at gate=1."""
    # gate = 0: no inflation, so the density is exactly Poisson.
    no_inflation = dist.ZeroInflatedPoisson(0.0, rate)
    draws = no_inflation.sample(random.PRNGKey(0), (20,))
    assert_allclose(
        no_inflation.log_prob(draws), dist.Poisson(rate).log_prob(draws), rtol=1e-6
    )

    # gate = 1: everything is inflated, matching a point mass at 0.
    full_inflation = dist.ZeroInflatedPoisson(1.0, rate)
    points = np.array([0.0, 1.0])
    assert_allclose(
        full_inflation.log_prob(points), dist.Delta(0.0).log_prob(points), rtol=1e-6
    )


@pytest.mark.parametrize("total_count", [1, 2, 3, 10])
@pytest.mark.parametrize("shape", [(1,), (3, 1), (2, 3, 1)])
def test_beta_binomial_log_prob(total_count, shape):
    """Monte Carlo check of BetaBinomial against a Beta-mixed Binomial."""
    concentration0 = np.exp(np.random.normal(size=shape))
    concentration1 = np.exp(np.random.normal(size=shape))
    value = jnp.arange(1 + total_count)
    num_samples = 100000

    # log E_p[Binomial(value | n, p)], estimated by a log-mean-exp over Beta draws.
    probs = np.random.beta(concentration1, concentration0, size=(num_samples,) + shape)
    binomial_log_probs = dist.Binomial(total_count, probs).log_prob(value)
    expected = logsumexp(binomial_log_probs, 0) - jnp.log(num_samples)

    beta_binomial = dist.BetaBinomial(concentration1, concentration0, total_count)
    assert_allclose(beta_binomial.log_prob(value), expected, rtol=0.02)


@pytest.mark.parametrize("n", [1, 2, 5, 10])
@pytest.mark.parametrize("shape", [(1,), (3, 1), (2, 3, 1)])
def test_beta_negative_binomial_log_prob(n, shape):
    """Monte Carlo check of BetaNegativeBinomial against a Beta-mixed NegativeBinomial."""
    concentration0 = np.exp(np.random.normal(size=shape))
    concentration1 = np.exp(np.random.normal(size=shape))
    value = jnp.arange(15)
    num_samples = 300000

    probs = np.random.beta(concentration1, concentration0, size=(num_samples,) + shape)
    # NegativeBinomialProbs with total_count=n (number of successes) and probs=p.
    nb_log_probs = dist.NegativeBinomialProbs(n, probs).log_prob(value)
    expected = logsumexp(nb_log_probs, 0) - jnp.log(num_samples)

    beta_nb = dist.BetaNegativeBinomial(concentration1, concentration0, n)
    assert_allclose(beta_nb.log_prob(value), expected, rtol=0.05)


@pytest.mark.parametrize("total_count", [1, 2, 3, 10])
@pytest.mark.parametrize("batch_shape", [(1,), (3, 1), (2, 3, 1)])
def test_dirichlet_multinomial_log_prob(total_count, batch_shape):
    """Monte Carlo check of DirichletMultinomial against a Dirichlet-mixed Multinomial."""
    event_shape = (3,)
    num_categories = event_shape[-1]
    concentration = np.exp(np.random.normal(size=batch_shape + event_shape))
    # Evaluate on every one-hot vector scaled by total_count; the singleton axes
    # let each one-hot broadcast against the batch dimensions.
    one_hot_layout = event_shape + (1,) * len(batch_shape) + event_shape
    value = total_count * jnp.eye(num_categories).reshape(one_hot_layout)

    num_samples = 100000
    probs = dist.Dirichlet(concentration).sample(random.PRNGKey(0), (num_samples, 1))
    multinomial_log_probs = dist.Multinomial(total_count, probs).log_prob(value)
    expected = logsumexp(multinomial_log_probs, 0) - jnp.log(num_samples)

    actual = dist.DirichletMultinomial(concentration, total_count).log_prob(value)
    assert_allclose(actual, expected, rtol=0.05)


@pytest.mark.parametrize("shape", [(1,), (3, 1), (2, 3, 1)])
def test_gamma_poisson_log_prob(shape):
    """Monte Carlo check of GammaPoisson against a Poisson with a Gamma-distributed rate."""
    gamma_conc = np.exp(np.random.normal(size=shape))
    gamma_rate = np.exp(np.random.normal(size=shape))
    value = jnp.arange(15)
    num_samples = 300000

    # numpy parameterizes the Gamma by scale, hence 1 / rate.
    gamma_scale = 1 / gamma_rate
    poisson_rate = np.random.gamma(
        gamma_conc, gamma_scale, size=(num_samples,) + shape
    )
    poisson_log_probs = dist.Poisson(poisson_rate).log_prob(value)
    expected = logsumexp(poisson_log_probs, 0) - jnp.log(num_samples)

    actual = dist.GammaPoisson(gamma_conc, gamma_rate).log_prob(value)
    assert_allclose(actual, expected, rtol=0.05)


@pytest.mark.parametrize("conc", [15.0, 20.0, 30.0])
def test_inverse_wishart_variance(conc):
    """Compare ``InverseWishart.variance`` with a Monte Carlo estimate.

    The standard elementwise inverse-Wishart variance (df nu, dimension p, scale Psi) is
    Var(X_ij) = ((nu - p + 1) * Psi_ij^2 + (nu - p - 1) * Psi_ii * Psi_jj)
                / ((nu - p) * (nu - p - 1)^2 * (nu - p - 3)),
    which on the diagonal reduces to 2 * Psi_ii^2 / ((nu - p - 1)^2 * (nu - p - 3)).

    NOTE(review): the simpler form
    (Psi_ij^2 + Psi_ii * Psi_jj) / ((nu - p - 1)^2 * (nu - p - 3))
    agrees with this only on the diagonal. With the rtol used below this test
    likely cannot tell the two apart off the diagonal, so confirm which one
    ``InverseWishart.variance`` implements.
    """
    # Non-trivial (non-diagonal) scale matrix.
    scale = np.array([[2.0, 0.5], [0.5, 1.0]])
    dim = scale.shape[0]
    # The variance only exists for conc > p + 3.
    assert conc > dim + 3

    iw = dist.InverseWishart(conc, scale_matrix=scale)
    draws = iw.sample(random.PRNGKey(42), sample_shape=(200000,))
    empirical_var = jnp.var(draws, axis=0)

    # Smaller conc gives a noisier Monte Carlo estimate, so allow more slack.
    tolerance = 0.15 if conc < 20 else 0.10
    assert_allclose(iw.variance, empirical_var, rtol=tolerance)


@pytest.mark.parametrize(
    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_log_prob_gradient(jax_dist, sp_dist, params):
    """Check ``jax.grad`` of the summed log_prob against central finite differences.

    For each differentiable parameter, every entry is shifted by ``+/- eps`` at
    once. The finite difference therefore approximates the directional derivative
    along the all-ones direction, i.e. the sum of the gradient's entries, which is
    what the final assertion compares against.
    """
    if jax_dist in [dist.LKJ, dist.LKJCholesky]:
        pytest.skip("we have separated tests for LKJCholesky distribution")
    if jax_dist is _ImproperWrapper:
        pytest.skip("no param for ImproperUniform to test for log_prob gradient")
    if (
        jax_dist in [dist.DoublyTruncatedPowerLaw, dist.Dagum]
        and jnp.result_type(float) == jnp.float32
    ):
        pytest.skip(f"{jax_dist.__name__} is tested with x64 only.")

    # One fixed value is scored under every (perturbed) parameter set.
    rng_key = random.PRNGKey(0)
    value = jax_dist(*params).sample(rng_key)
    if isinstance(jax_dist(*params), dist.IntervalCensoredDistribution):
        # IntervalCensoredDistribution takes interval (lo-hi) input but returns univariate samples
        value = jnp.stack([value, value + 0.1], axis=-1)

    def fn(*args):
        return jnp.sum(jax_dist(*args).log_prob(value))

    # Finite-difference step and default tolerances, loosened for some distributions.
    eps = 1e-3
    atol = 0.01
    rtol = 0.05
    if jax_dist is dist.EulerMaruyama:
        atol = 0.064
    elif jax_dist is dist.NegativeBinomialLogits:
        atol = 0.013

    for i in range(len(params)):
        if jax_dist is dist.EulerMaruyama and i == 1:
            # skip taking grad w.r.t. sde_fn
            continue
        if jax_dist is _SparseCAR and i == 3:
            # skip taking grad w.r.t. adj_matrix
            continue
        if jax_dist is dist.ZeroSumNormal and i != 0:
            # skip taking grad w.r.t. event_shape
            continue
        if isinstance(
            params[i], dist.Distribution
        ):  # skip taking grad w.r.t. base_dist
            continue
        if jax_dist is dist.DoublyTruncatedPowerLaw and i != 0:
            continue
        # Absent or integer-typed parameters are not differentiated.
        if params[i] is None or jnp.result_type(params[i]) in (jnp.int32, jnp.int64):
            continue
        actual_grad = jax.grad(fn, i)(*params)
        # Shift only the i-th parameter (all of its entries by eps) in each direction.
        args_lhs = [p if j != i else p - eps for j, p in enumerate(params)]
        args_rhs = [p if j != i else p + eps for j, p in enumerate(params)]
        fn_lhs = fn(*args_lhs)
        fn_rhs = fn(*args_rhs)
        # finite diff approximation
        expected_grad = (fn_rhs - fn_lhs) / (2.0 * eps)
        assert jnp.shape(actual_grad) == jnp.shape(params[i])
        if i == 0 and jax_dist is dist.Delta:
            # grad w.r.t. `value` of Delta distribution will be 0
            # but numerical value will give nan (= inf - inf)
            expected_grad = 0.0
        assert_allclose(jnp.sum(actual_grad), expected_grad, rtol=rtol, atol=atol)


@pytest.mark.parametrize(
    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
@pytest.mark.xfail(
    os.getenv("JAX_CHECK_TRACER_LEAKS") == "1",
    reason="Expected tracer leak for Gompertz: https://github.com/jax-ml/jax/issues/26972",
)
def test_mean_var(jax_dist, sp_dist, params):
    """Check ``mean``/``variance`` against scipy and/or against sample moments.

    Distributions without a usable mean/variance are skipped up front. When a
    scipy counterpart exists and the distribution is not a batched multivariate
    (and not VonMises / MultivariateStudentT / MatrixNormal), the analytic moments
    are compared with scipy's and with the moments of ``n`` draws. Otherwise a
    distribution-specific check (LKJ marginals, circular statistics, MatrixNormal
    vectorization) or a generic sample-moment comparison is used.
    """
    if jax_dist is _ImproperWrapper:
        pytest.skip("Improper distribution does not has mean/var implemented")
    if jax_dist is FoldedNormal:
        pytest.skip("Folded distribution does not has mean/var implemented")
    if jax_dist is dist.EulerMaruyama:
        pytest.skip("EulerMaruyama distribution does not has mean/var implemented")
    if jax_dist is dist.RelaxedBernoulliLogits:
        pytest.skip("RelaxedBernoulli distribution does not has mean/var implemented")
    if "SineSkewed" in jax_dist.__name__:
        pytest.skip("Skewed Distribution are not symmetric about location.")
    if jax_dist in (
        _TruncatedNormal,
        _TruncatedCauchy,
        dist.LeftTruncatedDistribution,
        dist.RightTruncatedDistribution,
        dist.TwoSidedTruncatedDistribution,
    ):
        pytest.skip("Truncated distributions do not has mean/var implemented")
    if jax_dist in (
        _LeftCensoredHalfNormal,
        _RightCensoredWeibull,
        _LeftCensoredNormal,
        _RightCensoredNormal,
        _LeftCensoredPoisson,
        _RightCensoredPoisson,
        _IntervalCensoredNormal,
        _IntervalCensoredWeibull,
        _IntervalCensoredPoisson,
        dist.LeftCensoredDistribution,
        dist.RightCensoredDistribution,
        dist.IntervalCensoredDistribution,
    ):
        pytest.skip("Censored distributions do not have mean/var implemented")
    if jax_dist is dist.ProjectedNormal:
        pytest.skip("Mean is defined in submanifold")
    if jax_dist is dist.InverseWishartCholesky:
        pytest.skip(
            "InverseWishartCholesky.mean is an approximation (Cholesky of mean matrix)"
        )
    if jax_dist is dist.InverseWishart:
        # Skip generic mean_var test; dedicated test_inverse_wishart_variance validates
        # the variance formula with appropriate df values and sample sizes
        pytest.skip(
            "InverseWishart variance tested in dedicated test_inverse_wishart_variance"
        )
    if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]:
        pytest.skip(
            f"{jax_dist.__name__} distribution does not has mean/var implemented"
        )
    if jax_dist in [dist.Dagum] and jnp.result_type(float) == jnp.float32:
        pytest.skip(f"{jax_dist.__name__} is tested with x64 only.")

    # Number of Monte Carlo draws; smaller for LKJ / SineBivariateVonMises
    # (presumably because sampling them is slower).
    n = (
        20000
        if jax_dist in [dist.LKJ, dist.LKJCholesky, dist.SineBivariateVonMises]
        else 200000
    )
    d_jax = jax_dist(*params)
    k = random.PRNGKey(0)
    # Draws are cast to float32 before computing sample moments.
    samples = d_jax.sample(k, sample_shape=(n,)).astype(np.float32)
    # check with suitable scipy implementation if available
    # XXX: VonMises is already tested below
    if (
        sp_dist
        and not _is_batched_multivariate(d_jax)
        and jax_dist
        not in [dist.VonMises, dist.MultivariateStudentT, dist.MatrixNormal]
    ):
        d_sp = sp_dist(*params)
        try:
            sp_mean = d_sp.mean()
        except TypeError:  # mvn does not have .mean() method
            sp_mean = d_sp.mean
        # for multivariate distns try .cov first
        if d_jax.event_shape:
            try:
                sp_var = jnp.diag(d_sp.cov())
            except TypeError:  # mvn does not have .cov() method
                sp_var = jnp.diag(d_sp.cov)
            except (AttributeError, ValueError):
                sp_var = d_sp.var()
        else:
            sp_var = d_sp.var()
        # Analytic moments vs scipy, then sample moments vs analytic (finite only).
        assert_allclose(d_jax.mean, sp_mean, rtol=0.01, atol=1e-7)
        if sp_var is not None:
            assert_allclose(d_jax.variance, sp_var, rtol=0.01, atol=1e-7)
        if jnp.all(jnp.isfinite(sp_mean)):
            assert_allclose(jnp.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
        if sp_var is not None and jnp.all(jnp.isfinite(sp_var)):
            assert_allclose(
                jnp.std(samples, 0), jnp.sqrt(d_jax.variance), rtol=0.06, atol=1e-2
            )
    elif jax_dist in [dist.LKJ, dist.LKJCholesky]:
        # LKJCholesky draws are Cholesky factors; L @ L^T recovers correlation matrices.
        if jax_dist is dist.LKJCholesky:
            corr_samples = jnp.matmul(samples, jnp.swapaxes(samples, -2, -1))
        else:
            corr_samples = samples
        dimension, concentration, _ = params
        # marginal of off-diagonal entries
        # (each entry r satisfies (r + 1) / 2 ~ Beta(a, a) with a as below)
        marginal = dist.Beta(
            concentration + 0.5 * (dimension - 2), concentration + 0.5 * (dimension - 2)
        )
        # scale statistics due to linear mapping
        marginal_mean = 2 * marginal.mean - 1
        marginal_std = 2 * jnp.sqrt(marginal.variance)
        expected_mean = jnp.broadcast_to(
            jnp.reshape(marginal_mean, jnp.shape(marginal_mean) + (1, 1)),
            jnp.shape(marginal_mean) + d_jax.event_shape,
        )
        expected_std = jnp.broadcast_to(
            jnp.reshape(marginal_std, jnp.shape(marginal_std) + (1, 1)),
            jnp.shape(marginal_std) + d_jax.event_shape,
        )
        # diagonal elements of correlation matrices are 1
        expected_mean = expected_mean * (1 - jnp.identity(dimension)) + jnp.identity(
            dimension
        )
        expected_std = expected_std * (1 - jnp.identity(dimension))

        assert_allclose(jnp.mean(corr_samples, axis=0), expected_mean, atol=0.02)
        assert_allclose(jnp.std(corr_samples, axis=0), expected_std, atol=0.02)
    elif jax_dist in [dist.VonMises]:
        # circular mean = sample mean
        assert_allclose(d_jax.mean, jnp.mean(samples, 0), rtol=0.05, atol=1e-2)

        # circular variance
        x, y = jnp.mean(jnp.cos(samples), 0), jnp.mean(jnp.sin(samples), 0)

        expected_variance = 1 - jnp.sqrt(x**2 + y**2)
        assert_allclose(d_jax.variance, expected_variance, rtol=0.05, atol=1e-2)
    elif jax_dist in [dist.SineBivariateVonMises]:
        # Compare against the circular mean of each of the two angles.
        phi_loc = _circ_mean(samples[..., 0])
        psi_loc = _circ_mean(samples[..., 1])

        assert_allclose(
            d_jax.mean, jnp.stack((phi_loc, psi_loc), axis=-1), rtol=0.05, atol=1e-2
        )
    elif jax_dist in [dist.MatrixNormal]:
        # NOTE(review): every `jnp.allclose(...)` call in this branch discards its
        # result, so the branch never fails the test. Turning them into asserts
        # needs the indexing below re-checked first (see the get_min_shape note).
        sample_shape = (200_000,)
        # use X ~ MN(loc, U, V) then vec(X) ~ MVN(vec(loc), kron(V, U))
        if len(d_jax.batch_shape) > 0:
            # Move the batch axes to the front and check each batch element separately.
            axes = [len(sample_shape) + i for i in range(len(d_jax.batch_shape))]
            axes = tuple(axes)
            samples_re = jnp.moveaxis(samples, axes, jnp.arange(len(axes)))
            subshape = samples_re.shape[: len(axes)]
            ixi = product(*[range(k) for k in subshape])
            for ix in ixi:
                # mean
                # NOTE(review): `min` on tuples compares lexicographically, not
                # elementwise (min((1, 0), (0, 2)) == (0, 2)); confirm this is the
                # intended clipping of `ix` to a parameter's batch shape.
                def get_min_shape(ix, batch_shape):
                    return min(ix, tuple(map(lambda x: x - 1, batch_shape)))

                ix_loc = get_min_shape(ix, d_jax.loc.shape[: len(ix)])
                jnp.allclose(
                    jnp.mean(samples_re[ix], 0),
                    jnp.squeeze(d_jax.mean[ix_loc]),
                    rtol=0.5,
                    atol=1e-2,
                )
                # cov
                # Column-major ("F") flattening gives vec(X).
                samples_mvn = jnp.squeeze(samples_re[ix]).reshape(
                    sample_shape + (-1,), order="F"
                )
                ix_col = get_min_shape(ix, d_jax.scale_tril_column.shape[: len(ix)])
                ix_row = get_min_shape(ix, d_jax.scale_tril_row.shape[: len(ix)])
                scale_tril = my_kron(
                    d_jax.scale_tril_column[ix_col],
                    d_jax.scale_tril_row[ix_row],
                )
                sample_scale_tril = jnp.linalg.cholesky(jnp.cov(samples_mvn.T))
                jnp.allclose(sample_scale_tril, scale_tril, atol=0.5, rtol=1e-2)
        else:  # unbatched
            # mean
            jnp.allclose(
                jnp.mean(samples, 0),
                jnp.squeeze(d_jax.mean),
                rtol=0.5,
                atol=1e-2,
            )
            # cov
            samples_mvn = jnp.squeeze(samples).reshape(sample_shape + (-1,), order="F")
            scale_tril = my_kron(
                jnp.squeeze(d_jax.scale_tril_column), jnp.squeeze(d_jax.scale_tril_row)
            )
            sample_scale_tril = jnp.linalg.cholesky(jnp.cov(samples_mvn.T))
            jnp.allclose(sample_scale_tril, scale_tril, atol=0.5, rtol=1e-2)
    else:
        # Generic check: sample moments vs analytic moments (finite values only).
        if jnp.all(jnp.isfinite(d_jax.mean)):
            assert_allclose(jnp.mean(samples, 0), d_jax.mean, rtol=0.05, atol=1e-2)
        if isinstance(d_jax, dist.CAR):
            pytest.skip("CAR distribution does not have `variance` implemented.")
        if isinstance(d_jax, dist.Gompertz):
            pytest.skip("Gompertz distribution does not have `variance` implemented.")
        if jnp.all(jnp.isfinite(d_jax.variance)):
            assert jnp.allclose(
                jnp.std(samples, 0), jnp.sqrt(d_jax.variance), rtol=0.05, atol=0.05
            )


@pytest.mark.parametrize(
    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
@pytest.mark.parametrize("prepend_shape", [(), (2,), (2, 3)])
def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
    """Check parameter validation and support checking, eagerly and under jit.

    For each constrained constructor argument, out-of-bounds and in-bounds values
    are generated. Out-of-bounds parameters must raise ValueError when
    ``validate_args=True``, both eagerly and under ``jax.jit``. In-bounds
    parameters must build a distribution whose log_prob matches scipy (when
    available) and that warns "Out-of-support" for values outside its support.
    """
    if jax_dist in (
        _TruncatedNormal,
        _TruncatedCauchy,
        _LeftCensoredHalfNormal,
        _RightCensoredWeibull,
        _LeftCensoredNormal,
        _RightCensoredNormal,
        _LeftCensoredPoisson,
        _RightCensoredPoisson,
        _IntervalCensoredNormal,
        _IntervalCensoredWeibull,
        _IntervalCensoredPoisson,
        _GaussianMixture,
        _Gaussian2DMixture,
        _GeneralMixture,
        _General2DMixture,
    ):
        pytest.skip(f"{jax_dist.__name__} is a function, not a class")

    # Positional constructor argument names (dropping `self`), aligned with `params`.
    dist_args = [p for p in inspect.getfullargspec(jax_dist.__init__)[0][1:]]

    valid_params, oob_params = list(params), list(params)
    key = random.PRNGKey(1)
    dependent_constraint = False
    for i in range(len(params)):
        # Arguments that are not perturbed for particular distributions.
        if (
            jax_dist in (_ImproperWrapper, dist.LKJ, dist.LKJCholesky)
            and dist_args[i] != "concentration"
        ):
            continue
        if "SineSkewed" in jax_dist.__name__ and dist_args[i] != "skewness":
            continue
        if jax_dist is dist.EulerMaruyama and dist_args[i] != "t":
            continue
        if (
            jax_dist is dist.TwoSidedTruncatedDistribution
            and dist_args[i] == "base_dist"
        ):
            continue
        if (
            issubclass(jax_dist, (dist.GaussianRandomWalk, dist.GaussianStateSpace))
            and dist_args[i] == "num_steps"
        ):
            continue
        if jax_dist is dist.ZeroSumNormal and dist_args[i] == "event_shape":
            continue
        if (
            jax_dist is dist.SineBivariateVonMises
            and dist_args[i] == "weighted_correlation"
        ):
            continue
        if params[i] is None:
            oob_params[i] = None
            valid_params[i] = None
            continue
        constraint = jax_dist.arg_constraints[dist_args[i]]
        # Dependent constraints cannot be sampled from independently; stop here and
        # skip the ValueError checks below.
        if isinstance(constraint, constraints._Dependent):
            dependent_constraint = True
            break
        key, key_gen = random.split(key)
        oob_params[i] = gen_values_outside_bounds(
            constraint, jnp.shape(params[i]), key_gen
        )
        valid_params[i] = gen_values_within_bounds(
            constraint, jnp.shape(params[i]), key_gen
        )
        # NOTE(review): the two adjustments below sit inside the loop, so they run
        # once per processed argument (e.g. df is incremented once per argument,
        # not just once); confirm this is intended.
        if jax_dist is dist.MultivariateStudentT:
            # As mean is only defined for df > 1 & we instantiate
            # scipy.stats.multivariate_t with same mean as jax_dist
            # we need to ensure this is defined, so force df >= 1
            valid_params[0] += 1

        if jax_dist is dist.LogUniform:
            # scipy.stats.loguniform take parameter a and b
            # which is a > 0 and b > a.
            # gen_values_within_bounds() generates just
            # a > 0 and b > 0. Then, make b = a + b.
            valid_params[1] += valid_params[0]

    # Without validation, out-of-bounds parameters must still construct.
    assert jax_dist(*oob_params)

    # Invalid parameter values throw ValueError
    if not dependent_constraint and (
        jax_dist is not _ImproperWrapper and "SineSkewed" not in jax_dist.__name__
    ):
        with pytest.raises(ValueError):
            jax_dist(*oob_params, validate_args=True)

        with pytest.raises(ValueError):
            # test error raised under jit omnistaging
            oob_params = jax.device_get(oob_params)

            def dist_gen_fn():
                d = jax_dist(*oob_params, validate_args=True)
                return d

            jax.jit(dist_gen_fn)()

    d = jax_dist(*valid_params, validate_args=True)

    # Test agreement of log density evaluation on randomly generated samples
    # with scipy's implementation when available.
    if (
        sp_dist
        and not _is_batched_multivariate(d)
        and not (d.event_shape and prepend_shape)
    ):
        valid_samples = gen_values_within_bounds(
            d.support, size=prepend_shape + d.batch_shape + d.event_shape
        )
        # Continuous scipy distributions expose logpdf, discrete ones logpmf.
        try:
            expected = sp_dist(*valid_params).logpdf(valid_samples)
        except AttributeError:
            expected = sp_dist(*valid_params).logpmf(valid_samples)
        assert_allclose(d.log_prob(valid_samples), expected, atol=1e-5, rtol=1e-5)

    # Out-of-support samples emit an "Out-of-support" UserWarning
    oob_samples = gen_values_outside_bounds(
        d.support, size=prepend_shape + d.batch_shape + d.event_shape
    )
    with pytest.warns(UserWarning, match="Out-of-support"):
        d.log_prob(oob_samples)

    with pytest.warns(UserWarning, match="Out-of-support"):
        # test warning work under jit omnistaging
        oob_samples = jax.device_get(oob_samples)
        valid_params = jax.device_get(valid_params)

        def log_prob_fn():
            d = jax_dist(*valid_params, validate_args=True)
            return d.log_prob(oob_samples)

        jax.jit(log_prob_fn)()


def test_omnistaging_invalid_param():
    """A negative scale raises ValueError even when validation runs under ``jax.jit``."""

    def f(x):
        bad_scale = -np.ones(2)
        return dist.LogNormal(x, bad_scale, validate_args=True).log_prob(0)

    with pytest.raises(ValueError, match="got invalid"):
        jax.jit(f)(0)


def test_omnistaging_invalid_sample():
    """Scoring a negative value under LogNormal warns "Out-of-support" inside ``jax.jit``."""

    def f(x):
        scale = np.ones(2)
        return dist.LogNormal(x, scale, validate_args=True).log_prob(-1)

    with pytest.warns(UserWarning, match="Out-of-support"):
        jax.jit(f)(0)


def test_categorical_log_prob_grad():
    """Categorical log_prob and its gradient agree for softmax-probs and logits inputs."""
    data = jnp.repeat(jnp.arange(3), 10)
    weights = jnp.arange(1, 4)

    def from_probs(x):
        categorical = dist.Categorical(jax.nn.softmax(x * weights))
        return categorical.log_prob(data).sum()

    def from_logits(x):
        categorical = dist.Categorical(logits=x * weights)
        return categorical.log_prob(data).sum()

    x = 0.5
    value_p, grad_p = jax.value_and_grad(from_probs)(x)
    value_l, grad_l = jax.value_and_grad(from_logits)(x)
    assert_allclose(value_p, value_l, rtol=1e-6)
    assert_allclose(grad_p, grad_l, atol=1e-4)


def test_beta_proportion_invalid_mean():
    """With validation enabled, ``BetaProportion(1.0, 1.0)`` raises the invalid-mean error."""
    expected_message = r"^BetaProportion distribution got invalid mean parameter\.$"
    with dist.distribution.validation_enabled():
        with pytest.raises(ValueError, match=expected_message):
            dist.BetaProportion(1.0, 1.0)


########################################
# Tests for constraints and transforms #
########################################


@pytest.mark.parametrize(
    "constraint, x, expected",
    [
        (constraints.boolean, np.array([True, False]), np.array([True, True])),
        (constraints.boolean, np.array([1, 1]), np.array([True, True])),
        (constraints.boolean, np.array([-1, 1]), np.array([False, True])),
        (
            constraints.corr_cholesky,
            np.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]),
            np.array([True, False]),
        ),  # NB: not lower_triangular
        (
            constraints.corr_cholesky,
            np.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]),
            np.array([False, False]),
        ),  # NB: not positive_diagonal & not unit_norm_row
        (
            constraints.corr_matrix,
            np.array([[[1, 0], [0, 1]], [[1, 0.1], [0, 1]]]),
            np.array([True, False]),
        ),  # NB: not lower_triangular
        (
            constraints.corr_matrix,
            np.array([[[1, 0], [1, 0]], [[1, 0], [0.5, 0.5]]]),
            np.array([False, False]),
        ),  # NB: not unit diagonal
        (constraints.greater_than(1), 3, True),
        (
            constraints.greater_than(1),
            np.array([-1, 1, 5]),
            np.array([False, False, True]),
        ),
        (constraints.integer_interval(-3, 5), 0, True),
        (
            constraints.integer_interval(-3, 5),
            np.array([-5, -3, 0, 1.1, 5, 7]),
            np.array([False, True, True, False, True, False]),
        ),
        (constraints.interval(-3, 5), 0, True),
        (
            constraints.interval(-3, 5),
            np.array([-5, -3, 0, 5, 7]),
            np.array([False, True, True, True, False]),
        ),
        (constraints.less_than(1), -2, True),
        (
            constraints.less_than(1),
            np.array([-1, 1, 5]),
            np.array([True, False, False]),
        ),
        (constraints.lower_cholesky, np.array([[1.0, 0.0], [-2.0, 0.1]]), True),
        (
            constraints.lower_cholesky,
            np.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]),
            np.array([False, False]),
        ),
        (constraints.nonnegative_integer, 3, True),
        (
            constraints.nonnegative_integer,
            np.array([-1.0, 0.0, 5.0]),
            np.array([False, True, True]),
        ),
        (constraints.positive, 3, True),
        (constraints.positive, np.array([-1, 0, 5]), np.array([False, False, True])),
        (constraints.positive_definite, np.array([[1.0, 0.3], [0.3, 1.0]]), True),
        (
            constraints.positive_definite,
            np.array([[[2.0, 0.4], [0.3, 2.0]], [[1.0, 0.1], [0.1, 0.0]]]),
            np.array([False, False]),
        ),
        (constraints.positive_integer, 3, True),
        (
            constraints.positive_integer,
            np.array([-1.0, 0.0, 5.0]),
            np.array([False, False, True]),
        ),
        (constraints.real, -1, True),
        (
            constraints.real,
            np.array([np.inf, -np.inf, np.nan, np.pi]),
            np.array([False, False, False, True]),
        ),
        (constraints.simplex, np.array([0.1, 0.3, 0.6]), True),
        (
            constraints.simplex,
            np.array([[0.1, 0.3, 0.6], [-0.1, 0.6, 0.5], [0.1, 0.6, 0.5]]),
            np.array([True, False, False]),
        ),
        (constraints.softplus_positive, 3, True),
        (
            constraints.softplus_positive,
            np.array([-1, 0, 5]),
            np.array([False, False, True]),
        ),
        (
            constraints.softplus_lower_cholesky,
            np.array([[1.0, 0.0], [-2.0, 0.1]]),
            True,
        ),
        (
            constraints.softplus_lower_cholesky,
            np.array([[[1.0, 0.0], [-2.0, -0.1]], [[1.0, 0.1], [2.0, 0.2]]]),
            np.array([False, False]),
        ),
        (constraints.unit_interval, 0.1, True),
        (
            constraints.unit_interval,
            np.array([-5, 0, 0.5, 1, 7]),
            np.array([False, True, True, True, False]),
        ),
        (
            constraints.sphere,
            np.array([[1, 0, 0], [0.5, 0.5, 0]]),
            np.array([True, False]),
        ),
        (
            constraints.open_interval(0.0, 1.0),
            np.array([-5, 0, 0.5, 1, 7]),
            np.array([False, False, True, False, False]),
        ),
    ],
)
def test_constraints(constraint, x, expected):
    """Check a constraint's membership test, ``feasible_like`` and bijection.

    ``constraint(x)`` must equal ``expected``; the value produced by
    ``constraint.feasible_like(x)`` must have the shape of ``x`` and satisfy
    the constraint; and when ``biject_to(constraint).inv`` is available it
    must map that feasible value to (approximately) zero.
    """
    # Compute the feasible value once; it is reused by every check below.
    feasible_value = constraint.feasible_like(x)
    # A constraint whose feasible values are floating point cannot be discrete.
    if jnp.issubdtype(jnp.result_type(feasible_value), jnp.floating):
        assert not constraint.is_discrete
    assert_array_equal(constraint(x), expected)

    assert jnp.shape(feasible_value) == jnp.shape(x)
    assert_allclose(constraint(feasible_value), jnp.full(jnp.shape(expected), True))

    try:
        inverse = biject_to(constraint).inv(feasible_value)
    except NotImplementedError:
        # NOTE(review): presumably raised when no bijection/inverse exists for
        # this constraint; there is nothing further to check in that case.
        pass
    else:
        assert_allclose(inverse, jnp.zeros_like(inverse), atol=2e-7)


@pytest.mark.parametrize(
    "constraint",
    [
        constraints.corr_cholesky,
        constraints.corr_matrix,
        constraints.greater_than(2),
        constraints.interval(-3, 5),
        constraints.l1_ball,
        constraints.less_than(1),
        constraints.lower_cholesky,
        constraints.scaled_unit_lower_cholesky,
        constraints.ordered_vector,
        constraints.positive,
        constraints.positive_definite,
        constraints.positive_ordered_vector,
        constraints.real,
        constraints.real_vector,
        constraints.simplex,
        constraints.softplus_positive,
        constraints.softplus_lower_cholesky,
        constraints.unit_interval,
        constraints.open_interval(0.0, 1.0),
    ],
    ids=lambda x: x.__class__,
)
@pytest.mark.parametrize("shape", [(), (1,), (3,), (6,), (3, 1), (1, 3), (5, 3)])
def test_biject_to(constraint, shape):
    """Check the bijection that ``biject_to`` returns for ``constraint``.

    Verifies that bounded constraints propagate their bounds to the
    codomain, that ``forward_shape``/``inverse_shape`` agree with real calls,
    that ``inv`` accepts NaN input, that outputs lie in the codomain and
    round-tripped inputs in the domain, that ``inv`` inverts the forward map,
    and -- for unbatched inputs only -- that ``log_abs_det_jacobian`` matches
    a reference computed with ``jax.jacobian``/``grad`` in both directions.
    """
    transform = biject_to(constraint)
    event_dim = transform.domain.event_dim
    # The codomain of the bijection must carry over the constraint's bounds.
    if isinstance(constraint, constraints._Interval):
        assert transform.codomain.upper_bound == constraint.upper_bound
        assert transform.codomain.lower_bound == constraint.lower_bound
    elif isinstance(constraint, constraints._GreaterThan):
        assert transform.codomain.lower_bound == constraint.lower_bound
    elif isinstance(constraint, constraints._LessThan):
        assert transform.codomain.upper_bound == constraint.upper_bound
    # Shapes with fewer dims than the domain's event dim are not valid inputs.
    if len(shape) < event_dim:
        return
    rng_key = random.PRNGKey(0)
    x = random.normal(rng_key, shape)
    y = transform(x)

    assert transform.forward_shape(x.shape) == y.shape
    assert transform.inverse_shape(y.shape) == x.shape

    # test that inv works for NaN arrays (only the output shape is checked):
    x_nan = transform.inv(jnp.full(jnp.shape(y), np.nan))
    assert x_nan.shape == x.shape

    # test codomain: every batch element of y must satisfy it
    batch_shape = shape if event_dim == 0 else shape[:-1]
    assert_array_equal(transform.codomain(y), jnp.ones(batch_shape, dtype=jnp.bool_))

    # test inv
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-4, rtol=1e-5)

    # test domain, currently all is constraints.real or constraints.real_vector
    assert_array_equal(transform.domain(z), jnp.ones(batch_shape))

    # test log_abs_det_jacobian
    actual = transform.log_abs_det_jacobian(x, y)
    assert jnp.shape(actual) == batch_shape
    # Jacobian-based reference values are only computed for unbatched inputs.
    # Each branch reduces the transform to a map with a square Jacobian so
    # that slogdet can be applied.
    if len(shape) == event_dim:
        if constraint is constraints.simplex:
            # Drop the last output row of the forward Jacobian and the last
            # input column of the inverse Jacobian to make both square.
            expected = np.linalg.slogdet(jax.jacobian(transform)(x)[:-1, :])[1]
            inv_expected = np.linalg.slogdet(jax.jacobian(transform.inv)(y)[:, :-1])[1]
        elif constraint in [
            constraints.real_vector,
            constraints.ordered_vector,
            constraints.positive_ordered_vector,
            constraints.l1_ball,
        ]:
            # Vector-to-vector maps whose Jacobians are used as-is.
            expected = np.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = np.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        elif constraint in [constraints.corr_cholesky, constraints.corr_matrix]:
            # Compare both directions on the strictly lower triangle
            # (diagonal=-1) of the matrix-valued output.
            vec_transform = lambda x: matrix_to_tril_vec(  # noqa: E731
                transform(x), diagonal=-1
            )
            y_tril = matrix_to_tril_vec(y, diagonal=-1)

            def inv_vec_transform(y):
                # Rebuild a full matrix from its strictly lower triangle
                # before applying the inverse transform.
                matrix = vec_to_tril_matrix(y, diagonal=-1)
                if constraint is constraints.corr_matrix:
                    # fill the upper triangular part
                    matrix = (
                        matrix
                        + jnp.swapaxes(matrix, -2, -1)
                        + jnp.identity(matrix.shape[-1])
                    )
                return transform.inv(matrix)

            expected = np.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = np.linalg.slogdet(jax.jacobian(inv_vec_transform)(y_tril))[1]
        elif constraint in [
            constraints.lower_cholesky,
            constraints.scaled_unit_lower_cholesky,
            constraints.positive_definite,
            constraints.softplus_lower_cholesky,
        ]:
            # Compare both directions on the lower triangle including the
            # diagonal (default `diagonal` of matrix_to_tril_vec).
            vec_transform = lambda x: matrix_to_tril_vec(transform(x))  # noqa: E731
            y_tril = matrix_to_tril_vec(y)

            def inv_vec_transform(y):
                # Rebuild a full matrix from its lower triangle before
                # applying the inverse transform.
                matrix = vec_to_tril_matrix(y)
                if constraint is constraints.positive_definite:
                    # fill the upper triangular part
                    matrix = (
                        matrix
                        + jnp.swapaxes(matrix, -2, -1)
                        - jnp.diag(jnp.diag(matrix))
                    )
                return transform.inv(matrix)

            expected = np.linalg.slogdet(jax.jacobian(vec_transform)(x))[1]
            inv_expected = np.linalg.slogdet(jax.jacobian(inv_vec_transform)(y_tril))[1]
        else:
            # NOTE(review): the remaining constraints are presumably scalar
            # (event_dim == 0) maps, since `grad` needs a scalar output.
            expected = jnp.log(jnp.abs(grad(transform)(x)))
            inv_expected = jnp.log(jnp.abs(grad(transform.inv)(y)))

        # The forward log-det must match the reference and be the negative of
        # the inverse transform's log-det.
        assert_allclose(actual, expected, atol=1e-5, rtol=1e-5)
        assert_allclose(actual, -inv_expected, atol=1e-5, rtol=1e-5)


# NB: skip transforms which are tested in `test_biject_to`
@pytest.mark.parametrize(
    "transform, event_shape",
    [
        (PermuteTransform(np.array([3, 0, 4, 1, 2])), (5,)),
        (PowerTransform(2.0), ()),
        (SoftplusTransform(), ()),
        (
            LowerCholeskyAffine(
                np.array([1.0, 2.0]), np.array([[0.6, 0.0], [1.5, 0.4]])
            ),
            (2,),
        ),
        (
            transforms.ComposeTransform(
                [
                    biject_to(constraints.simplex),
                    SimplexToOrderedTransform(0.0),
                    biject_to(constraints.ordered_vector).inv,
                ]
            ),
            (5,),
        ),
    ],
)
@pytest.mark.parametrize(
    "batch_shape",
    [
        (),
        (1,),
        (3,),
        (6,),
        (3, 1),
        (1, 3),
        (5, 3),
    ],
)
def test_bijective_transforms(transform, event_shape, batch_shape):
    """Check domain/codomain membership, inversion and log-det of a bijection.

    Inputs are drawn in the transform's domain; the inverse must round-trip,
    ``transform.inv`` must be wired back to ``transform``, and the forward and
    inverse log-dets must be negatives of each other (and, for unbatched
    inputs, match a Jacobian computed with ``jax.jacobian``/``grad``).
    """
    shape = batch_shape + event_shape
    rng_key = random.PRNGKey(0)
    # Map standard normal draws into the transform's domain.
    x = biject_to(transform.domain)(random.normal(rng_key, shape))
    y = transform(x)

    # test codomain
    assert_array_equal(transform.codomain(y), jnp.ones(batch_shape))

    # test inv, and that the inverse is wired back to this transform
    # (`inv` must also return the same object on every access)
    z = transform.inv(y)
    assert_allclose(x, z, atol=1e-6, rtol=1e-4)
    assert transform.inv.inv is transform
    assert transform.inv is transform.inv
    assert transform.domain is transform.inv.codomain
    assert transform.codomain is transform.inv.domain

    # test domain
    assert_array_equal(transform.domain(z), jnp.ones(batch_shape))

    # test log_abs_det_jacobian: forward and inverse log-dets are negatives
    actual = transform.log_abs_det_jacobian(x, y)
    assert_allclose(actual, -transform.inv.log_abs_det_jacobian(y, x))
    assert jnp.shape(actual) == batch_shape
    # Jacobian-based reference values are only computed for unbatched inputs.
    if len(shape) == transform.domain.event_dim:
        if len(event_shape) == 1:
            # Vector transforms: log|det| of the full Jacobian.
            expected = np.linalg.slogdet(jax.jacobian(transform)(x))[1]
            inv_expected = np.linalg.slogdet(jax.jacobian(transform.inv)(y))[1]
        else:
            # Scalar transforms: log|f'(x)|.
            expected = jnp.log(jnp.abs(grad(transform)(x)))
            inv_expected = jnp.log(jnp.abs(grad(transform.inv)(y)))
        assert_allclose(actual, expected, atol=1e-5)
        assert_allclose(actual, -inv_expected, atol=1e-5)


@pytest.mark.parametrize("batch_shape", [(), (5,)])
def test_composed_transform(batch_shape):
    """Log-det of ``affine -> lower_cholesky -> affine`` is the sum of its parts."""
    scale_by_two = transforms.AffineTransform(0, 2)
    to_tril = transforms.LowerCholeskyTransform()
    composed = transforms.ComposeTransform([scale_by_two, to_tril, scale_by_two])
    assert composed.domain.event_dim == 1
    assert composed.codomain.event_dim == 2

    x = np.random.normal(size=batch_shape + (6,))
    y = composed(x)
    log_det = composed.log_abs_det_jacobian(x, y)
    assert log_det.shape == batch_shape

    # The first affine scales the 6 input entries; the last one scales all 9
    # entries of the matrix produced by the lower-Cholesky step.
    inner = jnp.log(2) * 6
    middle = to_tril.log_abs_det_jacobian(x * 2, y / 2)
    outer = jnp.log(2) * 9
    assert_allclose(log_det, inner + middle + outer, rtol=1e-6)


@pytest.mark.parametrize("batch_shape", [(), (5,)])
def test_composed_transform_1(batch_shape):
    """Log-det of ``affine -> lower_cholesky -> lower_cholesky`` adds up correctly."""
    affine = transforms.AffineTransform(0, 2)
    tril = transforms.LowerCholeskyTransform()
    composed = transforms.ComposeTransform([affine, tril, tril])
    assert composed.domain.event_dim == 1
    assert composed.codomain.event_dim == 3

    x = np.random.normal(size=batch_shape + (6,))
    y = composed(x)
    log_det = composed.log_abs_det_jacobian(x, y)
    assert log_det.shape == batch_shape

    scaled = x * 2
    z = tril(scaled)
    # The second lower-Cholesky step acts on each last-axis row of `z`, so its
    # per-row log-dets are summed over that extra event dimension.
    expected_log_det = (
        jnp.log(2) * 6
        + tril.log_abs_det_jacobian(scaled, z)
        + tril.log_abs_det_jacobian(z, tril(z)).sum(-1)
    )
    assert_allclose(log_det, expected_log_det)


@pytest.mark.parametrize("batch_shape", [(), (5,)])
def test_simplex_to_order_transform(batch_shape):
    """Shape helpers of SimplexToOrderedTransform agree with actual calls."""
    weights = jnp.arange(5.0)
    simplex = jnp.broadcast_to(weights / weights.sum(), batch_shape + weights.shape)
    transform = SimplexToOrderedTransform()
    ordered = transform(simplex)
    assert ordered.shape == transform.forward_shape(simplex.shape)
    assert transform.inverse_shape(ordered.shape) == simplex.shape


@pytest.mark.parametrize("batch_shape", [(), (5,)])
@pytest.mark.parametrize("prepend_event_shape", [(), (4,)])
@pytest.mark.parametrize("sample_shape", [(), (7,)])
def test_transformed_distribution(batch_shape, prepend_event_shape, sample_shape):
    """log_prob of a TransformedDistribution follows the change-of-variables rule."""
    n_extra = len(prepend_event_shape)
    base_dist = (
        dist.Normal(0, 1)
        .expand(batch_shape + prepend_event_shape + (6,))
        .to_event(1 + n_extra)
    )
    affine = transforms.AffineTransform(0, 2)
    tril = transforms.LowerCholeskyTransform()
    d = dist.TransformedDistribution(base_dist, [affine, tril, affine])
    assert d.event_dim == 2 + n_extra

    y = d.sample(random.PRNGKey(0), sample_shape)
    composed = transforms.ComposeTransform([affine, tril, affine])
    x = composed.inv(y)
    assert x.shape == sample_shape + base_dist.shape()
    log_prob = d.log_prob(y)
    assert log_prob.shape == sample_shape + batch_shape
    log_det = composed.log_abs_det_jacobian(x, y)
    if prepend_event_shape:
        # The prepended event dim is left unreduced by the transform.
        log_det = log_det.sum(-1)
    assert_allclose(log_prob, base_dist.log_prob(x) - log_det, atol=1e-5)


@pytest.mark.parametrize(
    "transformed_dist",
    [
        lambda: dist.TransformedDistribution(
            dist.Normal(np.array([2.0, 3.0]), 1.0), transforms.ExpTransform()
        ),
        lambda: dist.TransformedDistribution(
            dist.Exponential(jnp.ones(2)),
            [
                transforms.PowerTransform(0.7),
                transforms.AffineTransform(0.0, jnp.ones(2) * 3),
            ],
        ),
    ],
)
def test_transformed_distribution_intermediates(transformed_dist):
    """Passing cached intermediates to log_prob must not change its value."""
    d = transformed_dist()
    sample, intermediates = d.sample_with_intermediates(random.PRNGKey(1))
    with_cache = d.log_prob(sample, intermediates)
    without_cache = d.log_prob(sample)
    assert_allclose(with_cache, without_cache)


def test_transformed_transformed_distribution():
    """Nesting TransformedDistributions flattens them into one base and chain.

    The nested distribution must have a Normal base with the power and affine
    transforms in order. It must sample the same values as applying the
    affine map by hand. Its ``sample_with_intermediates`` must record one
    intermediate per transform.
    """
    loc, scale = -2, 3
    dist1 = dist.TransformedDistribution(
        dist.Normal(2, 3), transforms.PowerTransform(2.0)
    )
    dist2 = dist.TransformedDistribution(dist1, transforms.AffineTransform(loc, scale))
    assert isinstance(dist2.base_dist, dist.Normal)
    assert len(dist2.transforms) == 2
    assert isinstance(dist2.transforms[0], transforms.PowerTransform)
    assert isinstance(dist2.transforms[1], transforms.AffineTransform)

    rng_key = random.PRNGKey(0)
    assert_allclose(loc + scale * dist1.sample(rng_key), dist2.sample(rng_key))
    # `sample_with_intermediates` returns a `(sample, intermediates)` pair;
    # unpack it, otherwise the length check below would always see the pair.
    sample, intermediates = dist2.sample_with_intermediates(rng_key)
    assert_allclose(sample, dist2.sample(rng_key))
    assert len(intermediates) == 2


def _make_iaf(input_dim, hidden_dims, rng_key):
    """Return an InverseAutoregressiveTransform backed by a freshly
    initialised AutoregressiveNN with ``param_dims=[1, 1]``."""
    init_fn, apply_fn = AutoregressiveNN(input_dim, hidden_dims, param_dims=[1, 1])
    _, params = init_fn(rng_key, (input_dim,))
    network = partial(apply_fn, params)
    return InverseAutoregressiveTransform(network)


@pytest.mark.parametrize(
    "ts",
    [
        lambda: [transforms.PowerTransform(0.7), transforms.AffineTransform(2.0, 3.0)],
        lambda: [transforms.ExpTransform()],
        lambda: [
            transforms.ComposeTransform(
                [transforms.AffineTransform(-2, 3), transforms.ExpTransform()]
            ),
            transforms.PowerTransform(3.0),
        ],
        lambda: [
            _make_iaf(5, hidden_dims=[10], rng_key=random.PRNGKey(0)),
            transforms.PermuteTransform(jnp.arange(5)[::-1]),
            _make_iaf(5, hidden_dims=[10], rng_key=random.PRNGKey(1)),
        ],
    ],
)
def test_compose_transform_with_intermediates(ts):
    """Cached intermediates reproduce the plain output and log-det."""
    transform = transforms.ComposeTransform(ts())
    x = random.normal(random.PRNGKey(2), (7, 5))
    y, cache = transform.call_with_intermediates(x)
    cached_logdet = transform.log_abs_det_jacobian(x, y, cache)
    assert_allclose(y, transform(x))
    assert_allclose(cached_logdet, transform.log_abs_det_jacobian(x, y))


@pytest.mark.parametrize("x_dim, y_dim", [(3, 3), (3, 4)])
def test_unpack_transform(x_dim, y_dim):
    """UnpackTransform.inv round-trips; a UserWarning is expected when both
    unpacked parts have the same size."""
    xy = np.random.randn(x_dim + y_dim)

    def unpack_fn(flat):
        return {"x": flat[:x_dim], "y": flat[x_dim:]}

    def pack_fn(parts):
        return jnp.concatenate([parts["x"], parts["y"]], axis=-1)

    transform = transforms.UnpackTransform(unpack_fn, pack_fn)
    unpacked = transform(xy)
    if x_dim != y_dim:
        packed = transform.inv(unpacked)
    else:
        with pytest.warns(UserWarning, match="UnpackTransform.inv"):
            packed = transform.inv(unpacked)

    assert_allclose(packed, xy)


@pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS)
def test_generated_sample_distribution(jax_dist, sp_dist, params, N_sample=100_000):
    """Kolmogorov-Smirnov check of samplers implemented here against SciPy.

    Only distributions whose sampler is built in this library rather than
    taken directly from JAX (e.g. Gumbel(loc, scale) built on JAX's
    Gumbel(0, 1)) are compared; all others are skipped. The KS test only runs
    for unbatched, scalar-event distributions that have a SciPy counterpart.
    """
    if jax_dist not in [dist.Gumbel]:
        pytest.skip(
            f"{jax_dist.__name__} sampling method taken from upstream, no need to "
            "test generated samples."
        )

    d = jax_dist(*params)
    if not (sp_dist and not d.event_shape and not d.batch_shape):
        return
    samples = d.sample(random.PRNGKey(11), (N_sample,))
    ks_result = osp.kstest(samples, sp_dist(*params).cdf)
    assert ks_result.pvalue > 0.05


@pytest.mark.parametrize(
    "jax_dist, params, support",
    [
        (dist.BernoulliLogits, (5.0,), np.arange(2)),
        (dist.BernoulliProbs, (0.5,), np.arange(2)),
        (dist.BinomialLogits, (4.5, 10), np.arange(11)),
        (dist.BinomialProbs, (0.5, 11), np.arange(12)),
        (dist.BetaBinomial, (2.0, 0.5, 12), np.arange(13)),
        (dist.CategoricalLogits, (np.array([3.0, 4.0, 5.0]),), np.arange(3)),
        (dist.CategoricalProbs, (np.array([0.1, 0.5, 0.4]),), np.arange(3)),
        (dist.DiscreteUniform, (2, 4), np.arange(2, 5)),
    ],
)
@pytest.mark.parametrize("batch_shape", [(5,), ()])
@pytest.mark.parametrize("expand", [False, True])
def test_enumerate_support_smoke(jax_dist, params, support, batch_shape, expand):
    """enumerate_support lists the support along a new leading axis, optionally
    expanded to the batch shape."""
    first, *rest = params
    batched_first = jnp.broadcast_to(first, batch_shape + jnp.shape(first))
    actual = jax_dist(batched_first, *rest).enumerate_support(expand=expand)
    expected = support.reshape((-1,) + (1,) * len(batch_shape))
    if expand:
        expected = jnp.broadcast_to(expected, support.shape + batch_shape)
    assert_allclose(actual, expected)


def test_zero_inflated_enumerate_support():
    """A zero-inflated Bernoulli enumerates the same support as its base."""
    base = dist.Bernoulli(0.5)
    zero_inflated = dist.ZeroInflatedDistribution(base, gate=0.5)
    assert zero_inflated.has_enumerate_support
    assert_allclose(zero_inflated.enumerate_support(), base.enumerate_support())


@pytest.mark.parametrize("jax_dist, sp_dist, params", CONTINUOUS + DISCRETE)
@pytest.mark.parametrize("prepend_shape", [(), (2, 3)])
@pytest.mark.parametrize("sample_shape", [(), (4,)])
def test_expand(jax_dist, sp_dist, params, prepend_shape, sample_shape):
    """Check shapes after ``expand`` and that incompatible expansions fail.

    Samples and ``log_prob`` of the expanded distribution must carry the new
    batch shape. Expanding again must prepend further batch dims. Expanding
    to a shape incompatible with the prepended dims must raise ValueError.
    """
    jax_dist = jax_dist(*params)
    new_batch_shape = prepend_shape + jax_dist.batch_shape
    expanded_dist = jax_dist.expand(new_batch_shape)
    rng_key = random.PRNGKey(0)
    samples = expanded_dist.sample(rng_key, sample_shape)
    assert expanded_dist.batch_shape == new_batch_shape
    if isinstance(jax_dist, dist.IntervalCensoredDistribution):
        # Interval-censored distributions take interval (lo-hi) input but
        # return univariate samples, so build an interval around each sample.
        samples = jnp.stack([samples, samples + 0.1], axis=-1)
    assert jnp.shape(samples) == sample_shape + new_batch_shape + jax_dist.event_shape
    assert expanded_dist.log_prob(samples).shape == sample_shape + new_batch_shape
    # test expand of expand
    assert (
        expanded_dist.expand((3,) + new_batch_shape).batch_shape
        == (3,) + new_batch_shape
    )
    # test expand error: the call itself must raise; no assert is needed.
    if prepend_shape:
        with pytest.raises(ValueError, match="Cannot broadcast distribution of shape"):
            expanded_dist.expand((3,) + jax_dist.batch_shape)


@pytest.mark.parametrize("base_shape", [(2, 1, 5), (3, 1), (2, 1, 1), (1, 1, 5)])
@pytest.mark.parametrize("event_dim", [0, 1, 2, 3])
@pytest.mark.parametrize("sample_shape", [(1000,), (1000, 7, 1), (1000, 1, 7)])
def test_expand_shuffle_regression(base_shape, event_dim, sample_shape):
    """Sample means of an expanded distribution match the broadcast locations,
    i.e. ``expand`` does not shuffle the base parameters."""
    target_shape = (2, 3, 5)
    event_dim = min(event_dim, len(base_shape))
    loc = random.normal(random.PRNGKey(0), base_shape) * 10
    base_dist = dist.Normal(loc, 0.1).to_event(event_dim)
    batch_rank = len(target_shape) - event_dim
    expanded_dist = base_dist.expand(target_shape[:batch_rank])
    samples = expanded_dist.sample(random.PRNGKey(1), sample_shape)
    expected_mean = jnp.broadcast_to(loc, sample_shape[1:] + expanded_dist.shape())
    assert_allclose(samples.mean(0), expected_mean, atol=0.1)


@pytest.mark.parametrize("batch_shape", [(), (4,), (10, 3)])
def test_sine_bivariate_von_mises_batch_shape(batch_shape):
    """Broadcasting a single batched location sets the distribution's batch
    shape, and samples append a trailing dimension of size 2."""
    zero = jnp.array(0.0)
    one = jnp.array(1.0)
    phi_loc = jnp.broadcast_to(zero, batch_shape)
    sine = SineBivariateVonMises(phi_loc, zero, one, one, jnp.array(0.1))
    assert sine.batch_shape == batch_shape

    samples = sine.sample(random.PRNGKey(0))
    assert samples.shape == batch_shape + (2,)


def test_sine_bivariate_von_mises_sample_mean():
    """With concentrations of 5000 and zero correlation, the circular sample
    mean recovers the locations."""
    loc = jnp.array([[2.0, -1.0], [-2, 1.0]])
    phi_loc, psi_loc = loc
    sine = SineBivariateVonMises(phi_loc, psi_loc, 5000, 5000, 0.0)
    samples = sine.sample(random.PRNGKey(0), (5000,))

    assert_allclose(_circ_mean(samples).T, loc, rtol=5e-3)


@pytest.mark.parametrize("batch_shape", [(), (4,)])
def test_polya_gamma(batch_shape, num_points=20000):
    """Truncated Polya-Gamma: density integrates to ~1, sampler mean is ~0.25."""
    d = dist.TruncatedPolyaGamma(batch_shape=batch_shape)
    rng_key = random.PRNGKey(0)

    # Riemann-sum the density over a uniform grid up to the truncation point.
    grid = jnp.linspace(1.0e-6, d.truncation_point, num_points)
    step = d.truncation_point / num_points
    total_mass = step * jnp.exp(logsumexp(d.log_prob(grid), axis=-1))
    assert_allclose(total_mass, jnp.ones(batch_shape), rtol=1.0e-4)

    # Compare the empirical mean of the approximate sampler against 0.25.
    # NOTE(review): the mean is taken over the last axis, which presumes the
    # sampler puts the sample dimension last -- confirm against the sampler.
    draws = d.sample(rng_key, sample_shape=(3000,))
    empirical_mean = jnp.mean(draws, axis=-1)
    assert_allclose(empirical_mean, 0.25 * jnp.ones(batch_shape), rtol=0.07)


@pytest.mark.parametrize(
    "extra_event_dims,expand_shape",
    [(0, (4, 3, 2, 1)), (0, (4, 3, 2, 2)), (1, (5, 4, 3, 2)), (2, (5, 4, 3))],
)
def test_expand_reshaped_distribution(extra_event_dims, expand_shape):
    """Expanding an expanded, reinterpreted MVN keeps its event shape and
    rejects shapes that cannot broadcast or that shrink the batch shape."""
    mvn = dist.MultivariateNormal(jnp.zeros((1, 6)), scale_tril=jnp.eye(6))
    full_shape = (4, 1, 1, 1, 6)
    reshaped_dist = mvn.expand([4, 1, 1, 1]).to_event(extra_event_dims)
    split = 4 - extra_event_dims
    batch_shape = full_shape[:split]
    event_shape = full_shape[split:]
    assert reshaped_dist.batch_shape == batch_shape
    assert reshaped_dist.event_shape == event_shape

    large = reshaped_dist.expand(expand_shape)
    assert large.batch_shape == expand_shape
    assert large.event_shape == event_shape

    expected_errors = (RuntimeError, ValueError)
    # A trailing dim of size 3 cannot broadcast against the batch shape.
    with pytest.raises(expected_errors):
        reshaped_dist.expand(expand_shape + (3,))

    # Dropping a leading dim would shrink the existing batch shape.
    with pytest.raises(expected_errors):
        large.expand(expand_shape[1:])


@pytest.mark.parametrize(
    "batch_shape, mask_shape",
    [((), ()), ((2,), ()), ((), (2,)), ((2,), (2,)), ((4, 2), (1, 2)), ((2,), (4, 2))],
)
@pytest.mark.parametrize("event_shape", [(), (3,)])
def test_mask(batch_shape, event_shape, mask_shape):
    """A masked log_prob is non-zero exactly where the broadcast mask is True."""
    jax_dist = (
        dist.Normal().expand(batch_shape + event_shape).to_event(len(event_shape))
    )
    mask = dist.Bernoulli(0.5).sample(random.PRNGKey(0), mask_shape)
    if not mask_shape:
        # Also exercise a plain Python bool mask.
        mask = bool(mask)
    samples = jax_dist.sample(random.PRNGKey(1))
    log_prob = jax_dist.mask(mask).log_prob(samples)
    full_mask = jnp.broadcast_to(mask, lax.broadcast_shapes(batch_shape, mask_shape))
    assert_allclose(log_prob != 0, full_mask)


@pytest.mark.parametrize("event_shape", [(), (4,), (2, 4)])
def test_mask_grad(event_shape):
    """Masking out rows that contain NaN keeps log_prob and its gradient finite."""
    event_axes = tuple(-i - 1 for i in range(len(event_shape)))

    def f(x, data):
        base_dist = dist.Beta(jnp.exp(x), jnp.ones(event_shape)).to_event()
        mask = jnp.all(jnp.isfinite(data), event_axes)
        log_prob = base_dist.mask(mask).log_prob(data)
        assert log_prob.shape == data.shape[: len(data.shape) - len(event_shape)]
        return log_prob.sum()

    data = np.array([[0.4, np.nan, 0.2, np.nan], [0.5, 0.5, 0.5, 0.5]])
    # `grad_x` avoids shadowing the module-level `grad` import.
    log_prob, grad_x = jax.value_and_grad(f)(1.0, data)
    assert jnp.isfinite(grad_x) and jnp.isfinite(log_prob)


@pytest.mark.parametrize(
    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_dist_pytree(jax_dist, sp_dist, params):
    """Distributions survive pytree flattening (``jax.jit``) and ``lax.map``.

    After a jit round-trip, every argument in ``arg_constraints`` must be
    unchanged. Sampling and ``log_prob`` must also agree with the non-jitted
    distribution.
    """

    def f(x):
        return jax_dist(*params)

    if jax_dist is _ImproperWrapper:
        pytest.skip("Cannot flatten ImproperUniform")
    if jax_dist is dist.EulerMaruyama:
        pytest.skip("EulerMaruyama doesn't define flatten/unflatten")
    # jax.jit flattens/unflattens the returned distribution; compute it once
    # and reuse it for the comparisons below.
    actual_dist = jax.jit(f)(0)
    lax.map(f, np.ones(3))  # this test for compatibility w.r.t. scan
    expected_dist = f(0)
    # Test that parameters do not change after flattening.
    for name in expected_dist.arg_constraints:
        expected_arg = getattr(expected_dist, name)
        actual_arg = getattr(actual_dist, name)
        assert actual_arg is not None, f"arg {name} is None"
        if np.issubdtype(np.asarray(expected_arg).dtype, np.number):
            assert_allclose(actual_arg, expected_arg, atol=1e-7)
        else:
            # Non-numeric args (e.g. keys) can only be compared structurally.
            assert (
                actual_arg.shape == expected_arg.shape
                and actual_arg.dtype == expected_arg.dtype
            )
    expected_sample = expected_dist.sample(random.PRNGKey(0))
    actual_sample = actual_dist.sample(random.PRNGKey(0))
    if isinstance(expected_dist, dist.IntervalCensoredDistribution):
        # Interval-censored distributions take interval (lo-hi) input but
        # return univariate samples, so build an interval around each sample.
        expected_sample = jnp.stack([expected_sample, expected_sample + 0.1], axis=-1)
        actual_sample = jnp.stack([actual_sample, actual_sample + 0.1], axis=-1)
    expected_log_prob = expected_dist.log_prob(expected_sample)
    actual_log_prob = actual_dist.log_prob(actual_sample)
    assert_allclose(actual_sample, expected_sample, rtol=1e-6)
    assert_allclose(actual_log_prob, expected_log_prob, rtol=1e-5)


@pytest.mark.parametrize(
    "method, arg", [("to_event", 1), ("mask", False), ("expand", [5])]
)
def test_special_dist_pytree(method, arg):
    """Wrapped distributions (to_event / mask / expand) work under jit and lax.map."""

    def make_wrapped(_):
        base = dist.Normal(np.zeros(1), np.ones(1))
        wrap = getattr(base, method)
        return wrap(arg)

    jax.jit(make_wrapped)(0)
    lax.map(make_wrapped, np.ones(3))


def test_expand_no_unnecessary_batch_shape_expansion():
    """ExpandedDistribution must not alter its base batch shape needlessly.

    ExpandedDistribution can mutate the ``batch_shape`` of its base
    distribution in order to make itself mappable (see #684). That mutation
    should not take place when no mapping operation is performed.
    """

    def make_expanded(arg):
        return dist.Normal(arg, arg).expand([10, 3, *arg.shape])

    def assert_equivalent(lhs, rhs):
        assert lhs.batch_shape == rhs.batch_shape
        assert lhs.base_dist.batch_shape == rhs.base_dist.batch_shape
        assert lhs.base_dist.event_shape == rhs.base_dist.event_shape
        assert jnp.allclose(lhs.base_dist.loc, rhs.base_dist.loc)
        assert jnp.allclose(lhs.base_dist.scale, rhs.base_dist.scale)

    for arg in (jnp.array(1.0), jnp.ones((2,)), jnp.ones((2, 2))):
        # Low level: tree_unflatten after tree_flatten is an identity operation.
        d = make_expanded(arg)
        children, aux_data = d.tree_flatten()
        assert_equivalent(d, type(d).tree_unflatten(aux_data, children))

        # High level: jitting a function that returns an ExpandedDistribution
        # (which inserts flatten/unflatten steps) must give the same object as
        # calling it directly.
        d = make_expanded(arg)
        dj = jax.jit(make_expanded)(arg)

        assert isinstance(d, dist.ExpandedDistribution)
        assert isinstance(dj, dist.ExpandedDistribution)
        assert_equivalent(d, dj)


@pytest.mark.parametrize("batch_shape", [(), (4,), (2, 3)], ids=str)
def test_kl_delta_normal_shape(batch_shape):
    """KL(Delta || Normal) has the common batch shape of its arguments."""
    v, loc, log_scale = (np.random.normal(size=batch_shape) for _ in range(3))
    p = dist.Delta(v)
    q = dist.Normal(loc, np.exp(log_scale))
    assert kl_divergence(p, q).shape == batch_shape


def test_kl_delta_normal():
    """KL(Delta(v, 10.0) || q) equals ``10.0 - q.log_prob(v)``."""
    v, loc, log_scale = (np.random.normal() for _ in range(3))
    offset = 10.0
    p = dist.Delta(v, offset)
    q = dist.Normal(loc, np.exp(log_scale))
    assert_allclose(kl_divergence(p, q), offset - q.log_prob(v))


@pytest.mark.parametrize("batch_shape", [(), (4,), (2, 3)], ids=str)
@pytest.mark.parametrize("event_shape", [(), (4,), (2, 3)], ids=str)
def test_kl_independent_normal(batch_shape, event_shape):
    """KL between Independent normals equals the elementwise KL summed over
    the event dims."""
    shape = batch_shape + event_shape
    ndims = len(event_shape)

    def random_normal():
        loc = np.random.normal(size=shape)
        scale = np.exp(np.random.normal(size=shape))
        return dist.Normal(loc, scale)

    p = random_normal()
    q = random_normal()
    actual = kl_divergence(dist.Independent(p, ndims), dist.Independent(q, ndims))
    expected = sum_rightmost(kl_divergence(p, q), ndims)
    assert_allclose(actual, expected)


@pytest.mark.parametrize("batch_shape", [(), (4,), (2, 3)], ids=str)
@pytest.mark.parametrize("event_shape", [(), (4,), (2, 3)], ids=str)
def test_kl_expanded_normal(batch_shape, event_shape):
    """KL between Independent expanded normals equals the summed elementwise KL."""
    shape = batch_shape + event_shape
    ndims = len(event_shape)

    def random_expanded_normal():
        loc = np.random.normal()
        scale = np.exp(np.random.normal())
        return dist.Normal(loc, scale).expand(shape)

    p = random_expanded_normal()
    q = random_expanded_normal()
    actual = kl_divergence(dist.Independent(p, ndims), dist.Independent(q, ndims))
    expected = sum_rightmost(kl_divergence(p, q), ndims)
    assert_allclose(actual, expected)


@pytest.mark.parametrize(
    "batch_shape_p, batch_shape_q",
    [
        ((1,), (1,)),
        ((2, 3), (2, 3)),
        ((5, 1, 3), (2, 3)),
        ((1, 3), (5, 2, 3)),
    ],
    ids=str,
)
@pytest.mark.parametrize("single_scale_p", [False, True], ids=str)
@pytest.mark.parametrize("single_loc_p", [False, True], ids=str)
@pytest.mark.parametrize("single_scale_q", [False, True], ids=str)
@pytest.mark.parametrize("single_loc_q", [False, True], ids=str)
def test_kl_multivariate_normal_consistency_with_independent_normals(
    batch_shape_p,
    batch_shape_q,
    single_scale_p,
    single_loc_p,
    single_scale_q,
    single_loc_q,
):
    """MVN KL with a diagonal ``scale_tril`` matches KL of independent normals,
    including when locs/scales are shared across the batch."""
    event_shape = (5,)

    def make_dists(loc_batch_shape, scales_batch_shape):
        locs = np.random.normal(size=loc_batch_shape + event_shape)
        scales = np.exp(np.random.normal(size=scales_batch_shape + event_shape) * 0.1)
        # Place `scales` on the diagonal of the two trailing axes; multiplying
        # by the identity leaves every off-diagonal entry exactly zero.
        scale_tril = scales[..., None] * jnp.eye(event_shape[-1])
        independent = dist.Normal(locs, scales).to_event(len(event_shape))
        mvn = dist.MultivariateNormal(locs, scale_tril=scale_tril)
        return independent, mvn

    p_uni, p_mvn = make_dists(
        () if single_loc_p else batch_shape_p, () if single_scale_p else batch_shape_p
    )
    q_uni, q_mvn = make_dists(
        () if single_loc_q else batch_shape_q, () if single_scale_q else batch_shape_q
    )

    mvn_kl = kl_divergence(p_mvn, q_mvn)
    independent_kl = kl_divergence(p_uni, q_uni)
    assert_allclose(mvn_kl, independent_kl, atol=1e-5)


def test_kl_multivariate_normal_nondiagonal_covariance():
    """KL between MVNs with a full covariance matches a precomputed reference."""
    p_mvn = dist.MultivariateNormal(np.zeros(2), covariance_matrix=np.eye(2))
    q_cov = np.array([[2, 0.8], [0.8, 0.5]])
    q_mvn = dist.MultivariateNormal(np.ones(2), covariance_matrix=q_cov)

    reference = 3.21138
    assert_allclose(kl_divergence(p_mvn, q_mvn), reference, atol=2e-5)


@pytest.mark.parametrize("shape", [(), (4,), (2, 3)], ids=str)
@pytest.mark.parametrize(
    "p_dist, q_dist",
    [
        (dist.Beta, dist.Beta),
        (dist.Gamma, dist.Gamma),
        (dist.Kumaraswamy, dist.Beta),
        (dist.Normal, dist.Normal),
        (dist.Weibull, dist.Gamma),
    ],
)
def test_kl_univariate(shape, p_dist, q_dist):
    """Analytic univariate KL agrees with a 10k-sample Monte Carlo estimate."""

    def make_dist(dist_class):
        params = {}
        for name, constraint in dist_class.arg_constraints.items():
            if constraint is constraints.real:
                value = np.random.normal(size=shape)
            elif constraint is constraints.positive:
                value = np.exp(np.random.normal(size=shape))
            else:
                raise ValueError(f"Missing pattern for param {name}.")
            params[name] = value
        d = dist_class(**params)
        if dist_class is dist.Kumaraswamy:
            # NOTE(review): presumably raises the series order used by the
            # Kumaraswamy-Beta KL so the analytic value is accurate.
            d.KL_KUMARASWAMY_BETA_TAYLOR_ORDER = 1000
        return d

    p = make_dist(p_dist)
    q = make_dist(q_dist)
    actual = kl_divergence(p, q)
    x = p.sample(random.PRNGKey(0), (10000,)).copy()
    mc_estimate = jnp.mean(p.log_prob(x) - q.log_prob(x), 0)
    assert_allclose(actual, mc_estimate, rtol=0.05)


@pytest.mark.parametrize("shape", [(3, 2, 10), (3, 2, 11), (10,), (11,)], ids=str)
def test_kl_circulant_normal_consistency(shape: tuple) -> None:
    """KL(diagonal Normal || CirculantNormal) matches the dense MVN KL."""
    key1, key2, key3, key4 = random.split(random.key(9), 4)
    n = shape[-1]
    p_loc = random.normal(key1, shape)
    p_scale = random.gamma(key2, 3, shape)
    p = dist.Normal(p_loc, p_scale).to_event(1)
    # Random positive spectrum with one entry per rfft frequency (n // 2 + 1).
    covariance_rfft = random.gamma(key4, 10, shape[:-1] + (n // 2 + 1,)) / 10
    q_loc = random.normal(key3, shape)
    q = dist.CirculantNormal(q_loc, covariance_rfft=covariance_rfft)
    actual = kl_divergence(p, q)

    p_dense = dist.MultivariateNormal(p.mean, jnp.eye(n) * p.variance[..., None])
    q_dense = dist.MultivariateNormal(q.mean, q.covariance_matrix)
    assert_allclose(actual, kl_divergence(p_dense, q_dense), rtol=1e-6)


@pytest.mark.parametrize("shape", [(4,), (2, 3)], ids=str)
def test_kl_dirichlet_dirichlet(shape):
    """Analytic Dirichlet/Dirichlet KL should match a Monte Carlo estimate."""
    p, q = (dist.Dirichlet(np.exp(np.random.normal(size=shape))) for _ in range(2))
    analytic = kl_divergence(p, q)
    draws = p.sample(random.PRNGKey(0), (10_000,)).copy()
    log_ratio = p.log_prob(draws) - q.log_prob(draws)
    assert_allclose(analytic, jnp.mean(log_ratio, 0), rtol=0.05)


def test_vmapped_binomial_p0():
    """Sampling a vmapped Binomial with probs=0 must terminate (no infinite loop)."""

    def draw(rng_key):
        total = 2 * (random.uniform(rng_key) > 0.5)
        sample_key = random.split(rng_key)[1]
        return dist.Binomial(total_count=total, probs=0).sample(sample_key)

    keys = random.split(random.PRNGKey(0), 1)
    jax.vmap(draw)(keys)


def _get_vmappable_dist_init_params(jax_dist):
    """Return the indices of ``jax_dist``'s positional init args that can be vmapped.

    Test-local helper classes (truncated / censored wrappers) are looked up by
    class name in a fixed table. For any other ``dist.Distribution`` subclass,
    an ``__init__`` parameter (excluding the leading ``self``) is vmappable when
    the ``vmap_over`` implementation dispatched for the class accepts a
    parameter of the same name (excluding its leading argument).

    Raises:
        ValueError: if ``jax_dist`` is neither in the table nor a
            ``dist.Distribution`` subclass.
    """
    known_indices = {
        "_TruncatedCauchy": [2, 3],
        "_TruncatedNormal": [2, 3],
        "_LeftCensoredHalfNormal": [1],
        "_RightCensoredWeibull": [2],
        "_LeftCensoredNormal": [2],
        "_RightCensoredNormal": [2],
        "_LeftCensoredPoisson": [1],
        "_RightCensoredPoisson": [1],
        "_IntervalCensoredNormal": [2, 3],
        "_IntervalCensoredWeibull": [2, 3],
        "_IntervalCensoredPoisson": [1, 2],
    }
    class_name = jax_dist.__name__
    if class_name in known_indices:
        # Hand out a fresh list so callers may mutate it freely.
        return list(known_indices[class_name])
    if not issubclass(jax_dist, dist.Distribution):
        raise ValueError

    init_names = list(inspect.signature(jax_dist.__init__).parameters)[1:]
    vmap_names = set(
        list(inspect.signature(vmap_over.dispatch(jax_dist)).parameters)[1:]
    )
    return [pos for pos, param in enumerate(init_names) if param in vmap_names]


def _allclose_or_equal(a1, a2):
    """Compare two pytree leaves.

    The comparison is chosen from the type of ``a1``. NumPy arrays use
    ``np.allclose`` and JAX arrays use ``jnp.allclose``. CSR sparse matrices
    are densified and then compared with ``np.allclose``. Any other leaf falls
    back to equality or identity.
    """
    comparators = (
        (np.ndarray, np.allclose),
        (jnp.ndarray, jnp.allclose),
        (csr_matrix, lambda left, right: np.allclose(left.todense(), right.todense())),
    )
    for leaf_type, close in comparators:
        if isinstance(a1, leaf_type):
            return close(a2, a1)
    return a2 == a1 or a2 is a1


def _tree_equal(t1, t2):
    """Return a JAX boolean that is True iff pytrees ``t1`` and ``t2`` match.

    Leaves are compared pairwise with ``_allclose_or_equal``. The per-leaf
    results are flattened into a single array and reduced with ``jnp.all``.
    """
    # ``jax.flatten_util`` is a submodule; a bare ``import jax`` does not
    # guarantee it is loaded, so import the function explicitly.
    from jax.flatten_util import ravel_pytree

    leaf_matches = jax.tree.map(_allclose_or_equal, t1, t2)
    return jnp.all(ravel_pytree(leaf_matches)[0])


@pytest.mark.parametrize(
    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_vmap_dist(jax_dist, sp_dist, params):
    """Check that distributions can be constructed and sampled under ``jax.vmap``.

    The test covers several ``in_axes``/``out_axes`` combinations. For each one,
    every constructor argument selected by ``in_axes`` gets a size-1 axis
    inserted, and the distribution is rebuilt under ``jax.vmap``. The result is
    compared leaf-wise (via ``_tree_equal``) with the directly constructed
    distribution. Sampling from the batched distribution must add exactly one
    leading axis.
    """
    param_names = list(inspect.signature(jax_dist).parameters.keys())
    # Keep only indices of positional args that were actually supplied. The
    # list is filtered rather than sliced by count, which could keep an
    # out-of-range index.
    vmappable_param_idxs = [
        idx for idx in _get_vmappable_dist_init_params(jax_dist) if idx < len(params)
    ]

    if len(vmappable_param_idxs) == 0:
        return

    def make_jax_dist(*params):
        return jax_dist(*params)

    def sample(d: dist.Distribution):
        return d.sample(random.PRNGKey(0))

    d = make_jax_dist(*params)

    if isinstance(d, _SparseCAR) and d.is_sparse:
        # In this case, since csr arrays are not jittable,
        # _SparseCAR has a csr_matrix as part of its pytree
        # definition (not as a pytree leaf). This causes pytree
        # operations like jax.tree.map to fail, since these functions
        # compare the pytree def of each of the arguments using ==
        # which is ambiguous for array-like objects.
        return

    in_out_axes_cases = [
        # vmap over all args
        (
            tuple(0 if i in vmappable_param_idxs else None for i in range(len(params))),
            0,
        ),
        # vmap over a single arg, out over all attributes of a distribution
        *(
            ([0 if i == idx else None for i in range(len(params))], 0)
            for idx in vmappable_param_idxs
            if params[idx] is not None
        ),
        # vmap over a single arg, out over the associated attribute of the distribution
        *(
            (
                [0 if i == idx else None for i in range(len(params))],
                vmap_over(d, **{param_names[idx]: 0}),
            )
            for idx in vmappable_param_idxs
            if params[idx] is not None
        ),
        # vmap over a single arg, axis=1, (out single attribute, axis=1)
        *(
            (
                [1 if i == idx else None for i in range(len(params))],
                vmap_over(d, **{param_names[idx]: 1}),
            )
            for idx in vmappable_param_idxs
            if isinstance(params[idx], jnp.ndarray)
            and jnp.array(params[idx]).ndim > 0
            # skip this distribution because _GeneralMixture.__init__ turns
            # 1d inputs into 0d attributes, thus breaks the expectations of
            # the vmapping test case where in_axes=1, only done for rank>=1 tensors.
            and jax_dist is not _GeneralMixture
        ),
    ]

    for in_axes, out_axes in in_out_axes_cases:
        # Insert a size-1 axis at position `ax` into every leaf of each vmapped arg.
        batched_params = [
            (
                jax.tree.map(lambda x: jnp.expand_dims(x, ax), arg)
                if isinstance(ax, int)
                else arg
            )
            for arg, ax in zip(params, in_axes)
        ]
        # Recreate the jax_dist to avoid side effects coming from `d.sample`
        # triggering lazy_property computations, which, in a few cases, break
        # vmap_over's expectations regarding existing attributes to be vmapped.
        d = make_jax_dist(*params)
        batched_d = jax.vmap(make_jax_dist, in_axes=in_axes, out_axes=out_axes)(
            *batched_params
        )
        eq = vmap(_tree_equal, in_axes=(out_axes, None))(batched_d, d)
        assert eq == jnp.array([True])

        samples_dist = sample(d)
        samples_batched_dist = jax.vmap(sample, in_axes=(out_axes,))(batched_d)
        assert samples_batched_dist.shape == (1, *samples_dist.shape)


def test_vmap_validate_args():
    """vmap-constructed distributions must honour ``validate_args`` (issue #1684)."""

    def build(validate):
        # Both loc and scale are vmapped over a length-2 axis of zeros.
        return jax.vmap(
            lambda loc, scale: dist.Normal(
                loc=loc, scale=scale, validate_args=validate
            ),
            in_axes=(0, 0),
        )(jnp.zeros((2,)), jnp.zeros((2,)))

    # Constructing under vmap with validation enabled must not raise.
    build(True)

    # Non-regression: with validation disabled, the flag is kept on the result.
    v_dist = build(False)
    assert not v_dist._validate_args


def test_explicit_validate_args():
    """Calling ``validate_args`` explicitly checks the parameters on demand."""
    # Valid parameters: explicit validation succeeds.
    dist.Normal(0, 1, validate_args=False).validate_args()

    # Negative scale: explicit validation raises.
    bad = dist.Normal(0, -1, validate_args=False)
    with pytest.raises(ValueError, match="got invalid scale parameter"):
        bad.validate_args()

    # Under jit, strict=False skips validation and strict=True raises.
    validate_jitted = jax.jit(
        lambda d, strict: d.validate_args(strict), static_argnames=["strict"]
    )
    validate_jitted(bad, False)
    with pytest.raises(RuntimeError, match="Cannot validate arguments"):
        validate_jitted(bad, True)


def test_get_args():
    """``get_args`` lists only parameters supplied to or derived by the constructor."""
    mvn = dist.MultivariateNormal(precision_matrix=jnp.eye(3))
    expected_args = {"loc", "precision_matrix", "scale_tril"}
    assert set(mvn.get_args()) == expected_args


def test_multinomial_abstract_total_count():
    """Multinomial sampling under jit with a traced ``total_count``."""
    key = random.PRNGKey(0)
    probs = jnp.array([0.2, 0.5, 0.3])

    def resample_from_counts(counts):
        # total_count is traced here; total_count_max supplies a static bound.
        return dist.Multinomial(
            counts.sum(-1), probs=probs, total_count_max=10
        ).sample(key)

    reference = dist.Multinomial(10, probs).sample(key)
    jitted_draw = jax.jit(resample_from_counts)(reference)
    assert_allclose(reference, jitted_draw, rtol=1e-6)


def test_dirichlet_multinomial_abstract_total_count():
    """DirichletMultinomial sampling under jit with a traced ``total_count``."""
    key = random.PRNGKey(0)
    concentration = jnp.array([0.2, 0.5, 0.3])

    def resample_from_counts(counts):
        return dist.DirichletMultinomial(
            concentration=concentration,
            total_count=counts.sum(-1),
            total_count_max=10,  # fails on 0.18.0
        ).sample(key)

    reference = dist.DirichletMultinomial(
        concentration=concentration, total_count=10
    ).sample(key)
    jitted_draw = jax.jit(resample_from_counts)(reference)
    assert_allclose(reference, jitted_draw, rtol=1e-6)


def test_normal_log_cdf():
    """Normal.log_cdf matches ``jax.scipy.stats.norm.logcdf``, and exp(log_cdf) == cdf."""
    loc = jnp.array([[0.0, -10.0, 20.0]])
    scale = jnp.array([[1, 5, 7]])
    # A column of values broadcasts against the (1, 3) row of parameters.
    values = jnp.linspace(-5, 5, 100).reshape(-1, 1)
    normal = dist.Normal(loc=loc, scale=scale)
    log_cdf = normal.log_cdf(values)
    reference = jax_norm.logcdf(loc=loc, scale=scale, x=values)
    assert_allclose(log_cdf, reference)
    assert_allclose(jnp.exp(log_cdf), normal.cdf(values), rtol=1e-6)


@pytest.mark.parametrize(
    "value",
    [
        -15.0,
        np.array([[-15.0], [-10.0], [-5.0]]),
        np.array([[[-15.0], [-10.0], [-5.0]], [[-14.0], [-9.0], [-4.0]]]),
    ],
)
def test_truncated_normal_log_prob_in_tail(value):
    """TruncatedNormal.log_prob agrees with jax truncnorm.logpdf deep in the tail."""
    loc = 1.35
    scale = jnp.geomspace(0.01, 1, 10)
    low, high = -20, -1.0
    # jax's truncnorm expects the bounds in standardized units.
    a = (low - loc) / scale
    b = (high - loc) / scale

    truncated = dist.TruncatedNormal(loc, scale, low=low, high=high)
    expected = jax_truncnorm.logpdf(value, loc=loc, scale=scale, a=a, b=b)
    assert_allclose(truncated.log_prob(value), expected, rtol=1e-06)


def test_sample_truncated_normal_in_tail():
    """Samples from a standard normal truncated to [-16, -15] must be finite.

    The truncation interval lies 15-16 standard deviations below the mean.
    """
    tail_dist = dist.TruncatedNormal(loc=0, scale=1, low=-16, high=-15)
    samples = tail_dist.sample(random.PRNGKey(0), sample_shape=(10_000,))
    # `isfinite` rejects NaN as well as +/-inf. A positive `all` is used instead
    # of `~...any()`: bitwise `~` on a plain Python bool yields a truthy int
    # and would make the assertion pass unconditionally.
    assert jnp.all(jnp.isfinite(samples))


@jax.enable_custom_prng()
def test_jax_custom_prng():
    """Sampling works and yields finite values with jax's custom PRNG enabled."""
    samples = dist.Normal(0, 5).sample(random.PRNGKey(0), sample_shape=(1000,))
    # `isfinite` also catches NaN; avoids the fragile bitwise `~` on a bool.
    assert jnp.all(jnp.isfinite(samples))


def _assert_not_jax_issue_19885(
    capfd: "pytest.CaptureFixture", func: Callable, *args, **kwargs
) -> object:
    """Call ``func`` jitted and eagerly and assert jax issue #19885 does not occur.

    jit-ing identity plus matrix multiplication leads to performance degradation as
    discussed in https://github.com/jax-ml/jax/issues/19885. This assertion verifies
    that the issue does not affect performance in numpyro. After each call, the
    captured stderr must not contain "MatMul reference implementation being
    executed".

    Args:
        capfd: pytest capture fixture; its stderr is read after every call.
        func: Callable under test.
        *args, **kwargs: Forwarded to ``func``.

    Returns:
        The result of the final (non-jitted) call. If the result has a
        ``block_until_ready`` method, it is called first and its return value
        is returned.
    """
    result = None
    for jit in [True, False]:
        # Exercise both the jitted and the eager variant, as the flag intends.
        call = jax.jit(func) if jit else func
        result = call(*args, **kwargs)
        block_until_ready = getattr(result, "block_until_ready", None)
        if block_until_ready:
            result = block_until_ready()
        _, err = capfd.readouterr()
        assert "MatMul reference implementation being executed" not in err, (
            f"jit: {jit}"
        )
    return result


@pytest.mark.xfail
def test_jax_issue_19885(capfd: pytest.CaptureFixture) -> None:
    """Minimal reproducer for jax issue #19885 (identity + matmul under jit).

    Marked xfail; presumably the upstream issue still reproduces on the
    tested jax versions.
    """

    def identity_plus_square(y) -> jnp.ndarray:
        eye = jnp.identity(y.shape[-1])
        return eye + jnp.matmul(y, y)

    batch = jnp.ones((20, 100, 100))
    _assert_not_jax_issue_19885(capfd, identity_plus_square, batch)


def test_lowrank_mvn_19885(capfd: pytest.CaptureFixture) -> None:
    """LowRankMultivariateNormal construction, sampling and log_prob avoid jax #19885."""
    batch_size, event_size, sample_size, rank = 100, 200, 40, 40
    loc, log_cov_diag = random.normal(random.key(0), (2, batch_size, event_size))
    cov_diag = jnp.exp(log_cov_diag)
    cov_factor = random.normal(random.key(1), (batch_size, event_size, rank))

    check = partial(_assert_not_jax_issue_19885, capfd)

    distribution = check(dist.LowRankMultivariateNormal, loc, cov_factor, cov_diag)
    # The dummy array only carries the sample shape into the (possibly jitted) call.
    x = check(
        lambda shape_carrier: distribution.sample(random.key(0), shape_carrier.shape),
        jnp.empty(sample_size),
    )
    assert x.shape == (sample_size, batch_size, event_size)

    log_prob = check(distribution.log_prob, x)
    assert log_prob.shape == (sample_size, batch_size)


def test_gaussian_random_walk_linear_recursive_equivalence():
    """GaussianRandomWalk equals iid Normal steps pushed through RecursiveLinearTransform(I)."""
    scale, num_steps = 3.7, 15
    walk = dist.GaussianRandomWalk(scale, num_steps)
    increments = dist.Normal(0, scale).expand([num_steps, 1]).to_event(2)
    recursive = dist.TransformedDistribution(
        increments,
        dist.transforms.RecursiveLinearTransform(jnp.eye(1)),
    )

    key = random.PRNGKey(7)
    x_walk = walk.sample(key)
    x_recursive = recursive.sample(key)
    assert jnp.allclose(x_walk, x_recursive.squeeze())
    assert jnp.allclose(walk.log_prob(x_walk), recursive.log_prob(x_recursive))
    assert jnp.allclose(dist1.log_prob(x1), dist2.log_prob(x2))


@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10000.0])
def test_sine_bivariate_von_mises_norm(conc):
    """The SineBivariateVonMises density should integrate to one over the torus.

    The integral is approximated with the rectangle rule on a regular
    ``num_samples x num_samples`` grid. ``endpoint=False`` makes the grid cover
    each period exactly once (-pi and pi are the same angle), so the cell area
    ``(2 pi / num_samples) ** 2`` matches the actual grid spacing.
    """
    # Named so it does not shadow the module-level `dist` alias.
    sine_vm = SineBivariateVonMises(0, 0, conc, conc, 0.0)
    num_samples = 500
    angles = jnp.linspace(-jnp.pi, jnp.pi, num_samples, endpoint=False)
    mesh = jnp.stack(jnp.meshgrid(angles, angles), axis=-1)
    cell_area = (2 * jnp.pi) ** 2 / num_samples**2
    integral_torus = (jnp.exp(sine_vm.log_prob(mesh)) * cell_area).sum()
    assert jnp.allclose(integral_torus, 1.0, rtol=1e-2)


def test_gaussian_random_walk_state_space_equivalence():
    """A Gaussian random walk equals a one-state state space model.

    The state space model uses a unit (identity) transition matrix; variance,
    samples and log densities must all agree.
    """
    scale, num_steps = 0.3, 4
    walk = dist.GaussianRandomWalk(scale, num_steps)
    state_space = dist.GaussianStateSpace(
        num_steps, jnp.eye(1), scale_tril=scale * jnp.eye(1)
    )
    # The state space model carries a trailing state axis of size one.
    drop_state = partial(jnp.squeeze, axis=-1)
    assert jnp.allclose(walk.variance, drop_state(state_space.variance))

    key = jax.random.key(18)
    x_walk = walk.sample(key, (3,))
    x_state_space = state_space.sample(key, (3,))
    assert jnp.allclose(x_walk, drop_state(x_state_space))

    assert jnp.allclose(walk.log_prob(x_walk), state_space.log_prob(x_state_space))


def test_consistent_pytree() -> None:
    """A cached lazy property must not change a distribution's pytree structure."""

    def fresh_mvn():
        return dist.MultivariateNormal(precision_matrix=jnp.eye(2))

    carry = fresh_mvn()
    # Touch the lazy property so its value is cached on the instance.
    _ = carry.covariance_matrix
    assert "covariance_matrix" in carry.__dict__

    # lax.scan requires the carry returned by the body to have the same pytree
    # structure as the initial carry, which here has the cached property.
    def body(_carry, _x):
        return fresh_mvn(), None

    jax.lax.scan(body, carry, jnp.arange(7))


def test_distribution_repr():
    """``repr`` reports the batch and event shapes after ``expand``/``to_event``."""
    # Expanding to batch (3, 4) and moving one batch dim into the event is
    # expected to give batch (3,) and event (4, 5, 5).
    result = repr(dist.Wishart(7, jnp.eye(5)).expand([3, 4]).to_event(1))
    assert "batch shape (3,)" in result
    # The membership test is required: a bare non-empty string is always truthy.
    assert "event shape (4, 5, 5)" in result


@pytest.mark.parametrize(
    "base_dist_class, base_params",
    [
        (dist.Normal, (0.0, 1.0)),
        (dist.Normal, (2.0, 0.5)),
        (dist.Cauchy, (0.0, 1.0)),
        (dist.Laplace, (0.0, 1.0)),
        (dist.Logistic, (0.0, 1.0)),
        (dist.StudentT, (2.0, 0.0, 1.0)),
    ],
)
@pytest.mark.parametrize("low", [-2.0, -1.0, 0.0])
def test_left_truncated_cdf(base_dist_class, base_params, low):
    """Test CDF for left truncated distributions."""
    truncated_dist = dist.LeftTruncatedDistribution(base_dist_class(*base_params), low)

    # The first point lies below the truncation point, the rest at or above it.
    test_values = low + jnp.array([-1.0, 0.0, 0.5, 1.0, 2.0])
    cdf_values = truncated_dist.cdf(test_values)

    # Shape is preserved and every value is a probability.
    assert cdf_values.shape == test_values.shape
    assert jnp.all((cdf_values >= 0.0) & (cdf_values <= 1.0))

    below, inside = cdf_values[0], cdf_values[1:]
    # Below the truncation point the CDF vanishes.
    assert_allclose(below, 0.0, atol=1e-6)
    # Non-decreasing inside the support, up to small numerical errors.
    assert jnp.all(jnp.diff(inside) >= -1e-6)

    # cdf(icdf(q)) must round-trip.
    quantiles = jnp.array([0.1, 0.3, 0.5, 0.7, 0.9])
    round_trip = truncated_dist.cdf(truncated_dist.icdf(quantiles))
    assert_allclose(round_trip, quantiles, atol=1e-5)


@pytest.mark.parametrize(
    "base_dist_class, base_params",
    [
        (dist.Normal, (0.0, 1.0)),
        (dist.Normal, (-1.0, 2.0)),
        (dist.Cauchy, (0.0, 1.0)),
        (dist.Laplace, (0.0, 1.0)),
        (dist.Logistic, (0.0, 1.0)),
        (dist.StudentT, (2.0, 0.0, 1.0)),
    ],
)
@pytest.mark.parametrize("high", [0.0, 1.0, 2.0])
def test_right_truncated_cdf(base_dist_class, base_params, high):
    """Test CDF for right truncated distributions."""
    truncated_dist = dist.RightTruncatedDistribution(
        base_dist_class(*base_params), high
    )

    # All points but the last lie at or below the truncation point.
    test_values = high + jnp.array([-2.0, -1.0, -0.5, 0.0, 1.0])
    cdf_values = truncated_dist.cdf(test_values)

    # Shape is preserved and every value is a probability.
    assert cdf_values.shape == test_values.shape
    assert jnp.all((cdf_values >= 0.0) & (cdf_values <= 1.0))

    inside, above = cdf_values[:-1], cdf_values[-1]
    # Above the truncation point the CDF saturates at one.
    assert_allclose(above, 1.0, atol=1e-6)
    # Non-decreasing inside the support, up to small numerical errors.
    assert jnp.all(jnp.diff(inside) >= -1e-6)

    # cdf(icdf(q)) must round-trip.
    quantiles = jnp.array([0.1, 0.3, 0.5, 0.7, 0.9])
    round_trip = truncated_dist.cdf(truncated_dist.icdf(quantiles))
    assert_allclose(round_trip, quantiles, atol=1e-5)


@pytest.mark.parametrize(
    "base_dist_class, base_params",
    [
        (dist.Normal, (0.0, 1.0)),
        (dist.Normal, (1.0, 0.8)),
        (dist.Cauchy, (0.0, 1.0)),
        (dist.Laplace, (0.0, 1.0)),
        (dist.Logistic, (0.0, 1.0)),
        (dist.StudentT, (2.0, 0.0, 1.0)),
    ],
)
@pytest.mark.parametrize("low, high", [(-2.0, 2.0), (-1.0, 1.0), (0.0, 3.0)])
def test_two_sided_truncated_cdf(base_dist_class, base_params, low, high):
    """Test CDF for two-sided truncated distributions."""
    truncated_dist = dist.TwoSidedTruncatedDistribution(
        base_dist_class(*base_params), low, high
    )

    # One point below, three inside (both bounds and the midpoint), one above.
    midpoint = (low + high) / 2
    test_values = jnp.array([low - 1.0, low, midpoint, high, high + 1.0])
    cdf_values = truncated_dist.cdf(test_values)

    # Shape is preserved and every value is a probability.
    assert cdf_values.shape == test_values.shape
    assert jnp.all((cdf_values >= 0.0) & (cdf_values <= 1.0))

    below, inside, above = cdf_values[0], cdf_values[1:-1], cdf_values[-1]
    # Outside the truncation interval the CDF is exactly 0 or 1.
    assert_allclose(below, 0.0, atol=1e-6)
    assert_allclose(above, 1.0, atol=1e-6)
    # Non-decreasing inside the support, up to small numerical errors.
    assert jnp.all(jnp.diff(inside) >= -1e-6)

    # cdf(icdf(q)) must round-trip.
    quantiles = jnp.array([0.1, 0.3, 0.5, 0.7, 0.9])
    round_trip = truncated_dist.cdf(truncated_dist.icdf(quantiles))
    assert_allclose(round_trip, quantiles, atol=1e-5)


@pytest.mark.parametrize("loc, scale", [(0.0, 1.0), (2.0, 0.5), (-1.0, 2.0)])
@pytest.mark.parametrize(
    "low, high", [(-2.0, 2.0), (-1.0, 1.0), (0.0, 3.0), (None, 2.0), (-2.0, None)]
)
def test_truncated_normal_cdf_scipy_consistency(loc, scale, low, high):
    """Truncated-normal CDF agrees with ``jax.scipy.stats.truncnorm.cdf``.

    ``jax_truncnorm`` (imported at module level) follows scipy's convention and
    takes the bounds ``a``/``b`` in standardized units ``(bound - loc) / scale``.
    A missing bound (``None``) maps to an infinite standardized bound. Test
    values are drawn strictly inside the truncation range.
    """
    if low is None and high is None:
        pytest.skip("Cannot test when both bounds are None")

    base_dist = dist.Normal(loc, scale)
    if low is None:
        truncated_dist = dist.RightTruncatedDistribution(base_dist, high)
        a, b = -jnp.inf, (high - loc) / scale
        test_values = jnp.linspace(high - 3 * scale, high - 0.1 * scale, 10)
    elif high is None:
        truncated_dist = dist.LeftTruncatedDistribution(base_dist, low)
        a, b = (low - loc) / scale, jnp.inf
        test_values = jnp.linspace(low + 0.1 * scale, low + 3 * scale, 10)
    else:
        truncated_dist = dist.TwoSidedTruncatedDistribution(base_dist, low, high)
        a, b = (low - loc) / scale, (high - loc) / scale
        width = high - low
        test_values = jnp.linspace(low + 0.1 * width, high - 0.1 * width, 10)

    numpyro_cdf = truncated_dist.cdf(test_values)
    jax_cdf = jax_truncnorm.cdf(test_values, a=a, b=b, loc=loc, scale=scale)

    assert_allclose(numpyro_cdf, jax_cdf, rtol=1e-5, atol=1e-6)


def test_truncated_cdf_edge_cases():
    """Truncated CDFs stay finite and within [0, 1] for extreme truncations."""
    base_dist = dist.Normal(0.0, 1.0)
    extreme_cases = [
        dist.LeftTruncatedDistribution(base_dist, 5.0),  # far in the tail
        dist.RightTruncatedDistribution(base_dist, -5.0),  # far in the tail
        dist.TwoSidedTruncatedDistribution(base_dist, -0.1, 0.1),  # very narrow
    ]

    test_values = jnp.array([-10.0, 0.0, 10.0])
    for truncated in extreme_cases:
        cdf_values = truncated.cdf(test_values)
        assert jnp.all(jnp.isfinite(cdf_values))
        assert jnp.all(cdf_values >= 0.0) and jnp.all(cdf_values <= 1.0)


@pytest.mark.parametrize("batch_shape", [(), (3,)])
def test_truncated_cdf_batch_shapes(batch_shape):
    """Test that CDF works correctly with batch shapes."""
    if batch_shape:
        loc, scale = jnp.zeros(batch_shape), jnp.ones(batch_shape)
        low, high = -jnp.ones(batch_shape), jnp.ones(batch_shape)
    else:
        loc, scale, low, high = 0.0, 1.0, -1.0, 1.0

    truncated_dist = dist.TwoSidedTruncatedDistribution(
        dist.Normal(loc, scale), low, high
    )

    # A scalar value broadcasts to the batch shape.
    assert truncated_dist.cdf(0.0).shape == batch_shape

    probe = [-2.0, 0.0, 2.0]
    if batch_shape:
        # Batched: evaluate one scalar at a time to avoid broadcasting issues.
        for value in probe:
            assert truncated_dist.cdf(value).shape == batch_shape
    else:
        values = jnp.array(probe)
        assert truncated_dist.cdf(values).shape == values.shape


# Base distributions (continuous and discrete) wrapped by the censoring tests.
param_cens_dist = pytest.mark.parametrize(
    "base_dist_class, base_params",
    [
        (dist.Normal, (0.0, 1.0)),
        (dist.Normal, (2.0, 0.5)),
        (dist.Cauchy, (0.0, 1.0)),
        (dist.Laplace, (0.0, 1.0)),
        (dist.Logistic, (0.0, 1.0)),
        (dist.StudentT, (2.0, 0.0, 1.0)),
        (dist.HalfNormal, (1.0,)),
        (dist.Poisson, (1.0,)),
        (dist.GammaPoisson, (1, 1)),
        (dist.Weibull, (1.0, 1.0)),
    ],
)
# Censoring indicators: 0.0 = observed, 1.0 = censored.
param_censored, param_left_censored, param_right_censored = (
    pytest.mark.parametrize(argname, [0.0, 1.0])
    for argname in ("censored", "left_censored", "right_censored")
)


@param_cens_dist
@param_censored
def test_left_censored_logprob(base_dist_class, base_params, censored):
    """Test log_prob for left censored distributions.

    If the point is observed, log_prob equals the base log-density. If it is
    censored, log_prob equals the log CDF of the base distribution.
    """
    base_dist = base_dist_class(*base_params)
    censored_dist = dist.LeftCensoredDistribution(base_dist, censored)

    # A shape-(1,) point inside the base distribution's support.
    test_values = base_dist.support.feasible_like(jnp.zeros((1,)))
    logp_values = censored_dist.log_prob(test_values)

    assert logp_values.shape == test_values.shape
    assert jnp.all(jnp.isfinite(logp_values))

    observed_logp = base_dist.log_prob(test_values)
    # Clip the CDF away from zero before taking the log so it stays finite.
    tiny = 100.0 * jnp.finfo(jnp.result_type(test_values, float)).tiny
    censored_logp = jnp.log(
        jnp.clip(base_dist.cdf(jnp.array(test_values)), tiny, 1.0)
    )
    expected = jnp.where(censored, censored_logp, observed_logp)
    assert jnp.abs(logp_values - expected).max() < 1e-6


@param_cens_dist
@param_censored
def test_right_censored_logprob(base_dist_class, base_params, censored):
    """Test log_prob for right censored distributions.

    If the point is observed, log_prob equals the base log-density. If it is
    censored, log_prob equals log(1 - CDF) of the base distribution.
    """
    base_dist = base_dist_class(*base_params)
    censored_dist = dist.RightCensoredDistribution(base_dist, censored)

    # A shape-(1,) point inside the base distribution's support.
    test_values = base_dist.support.feasible_like(jnp.zeros((1,)))
    logp_values = censored_dist.log_prob(test_values)

    assert logp_values.shape == test_values.shape
    assert jnp.all(jnp.isfinite(logp_values))

    observed_logp = base_dist.log_prob(test_values)
    log_survival = jnp.log1p(-base_dist.cdf(test_values))
    expected = jnp.where(censored, log_survival, observed_logp)
    assert jnp.abs(logp_values - expected).max() < 1e-6


@param_cens_dist
@param_left_censored
@param_right_censored
def test_interval_censored_logprob(
    base_dist_class, base_params, left_censored, right_censored
):
    """Test log_prob for interval censored distributions.

    Values are ``(lower, upper)`` pairs stacked on the last axis. Each
    combination of censoring flags is checked against a reference built from
    the left/right censored distributions:

    * left only:  left-censored log_prob at the upper bound;
    * right only: right-censored log_prob at the lower bound;
    * neither:    log of the mass between the two bounds;
    * both:       log of the mass outside the two bounds.
    """
    base_dist = base_dist_class(*base_params)
    interval_censored_dist = dist.IntervalCensoredDistribution(
        base_dist, left_censored, right_censored
    )
    left_censored_dist = dist.LeftCensoredDistribution(base_dist, 1.0)
    right_censored_dist = dist.RightCensoredDistribution(base_dist, 1.0)

    # Test points: a feasible lower bound and an upper bound one unit above it.
    test_values_lower = base_dist.support.feasible_like(jnp.zeros((1,)))
    test_values_upper = test_values_lower + 1
    test_values = jnp.stack([test_values_lower, test_values_upper], axis=-1)

    # Compute log_prob
    logp_values = interval_censored_dist.log_prob(test_values)
    # note: for left censoring, the value is cdf up to right point
    logp_values_left_lower = left_censored_dist.log_prob(test_values_lower)
    logp_values_left_upper = left_censored_dist.log_prob(test_values_upper)
    # note: for right censoring, the value is the survival function
    # (1 - cdf) from left point up to infinity
    logp_values_right_lower = right_censored_dist.log_prob(test_values_lower)
    logp_values_right_upper = right_censored_dist.log_prob(test_values_upper)

    # Basic properties
    assert logp_values.shape == test_values_lower.shape
    assert jnp.all(jnp.isfinite(logp_values))

    lc = jnp.array(left_censored, dtype=bool)
    rc = jnp.array(right_censored, dtype=bool)
    m_left = lc & ~rc
    m_right = rc & ~lc

    # For left censored points, log_prob should match left censored distribution, evaluated at right point
    # For right censored points, log_prob should match right censored distribution, evaluated at left point
    left_diff = jnp.where(m_left, logp_values - logp_values_left_upper, 0.0)
    right_diff = jnp.where(m_right, logp_values - logp_values_right_lower, 0.0)
    assert jnp.abs(left_diff).max() < 1e-6
    assert jnp.abs(right_diff).max() < 1e-6

    # For interval censored points (lower bound x, upper bound y), log_prob
    # should be the log probability mass between x and y:
    #   logp_left(x)  = log CDF(x)
    #   logp_right(y) = log (1 - CDF(y)) = log S(y)
    #   logp_interval(x, y) = log (CDF(y) - CDF(x))
    #                       = log (1 - S(y) - CDF(x))
    #                       = log1p(-(exp(logp_right(y)) + exp(logp_left(x))))
    m_interval = (~lc) & (~rc)
    interval_comparison = jnp.log1p(
        -(jnp.exp(logp_values_right_upper) + jnp.exp(logp_values_left_lower))
    )
    interval_diff = jnp.where(m_interval, logp_values - interval_comparison, 0.0)
    assert jnp.abs(interval_diff).max() < 1e-5

    # For doubly censored points, log_prob should be log (1 - mass between x and y):
    #   logp_double(x, y) = log (1 - (CDF(y) - CDF(x))) = log (S(y) + CDF(x))
    #                     = log (exp(logp_right(y)) + exp(logp_left(x)))
    m_double = lc & rc
    double_comparison = jnp.log(
        jnp.exp(logp_values_right_upper) + jnp.exp(logp_values_left_lower)
    )
    double_diff = jnp.where(m_double, logp_values - double_comparison, 0.0)
    assert jnp.abs(double_diff).max() < 1e-5


def test_censored_logprob_edge_cases():
    """Censored log_prob stays finite for values far in either tail."""
    base_dist = dist.Normal(0.0, 1.0)
    test_values = jnp.array([-10.0, 0.0, 10.0])

    for censoring_cls in (
        dist.LeftCensoredDistribution,
        dist.RightCensoredDistribution,
    ):
        # Every point is marked as censored.
        logp = censoring_cls(base_dist, 1).log_prob(test_values)
        assert jnp.all(jnp.isfinite(logp))


@param_left_censored
@param_right_censored
def test_interval_censored_logprob_edge_cases(left_censored, right_censored):
    """Interval-censored log_prob stays finite for extreme or degenerate bounds."""
    interval_censored = dist.IntervalCensoredDistribution(
        dist.Normal(0.0, 1.0), left_censored, right_censored
    )
    # Each row is a (lower, upper) pair.
    bounds = jnp.array(
        [
            [-10.0, -5.0],  # very negative
            [0.0, 0.5],  # small difference
            [0.0, 0.0],  # exact observation
            [5.0, 10.0],  # very positive
        ]
    )
    assert jnp.all(jnp.isfinite(interval_censored.log_prob(bounds)))


@pytest.mark.parametrize("batch_shape", [(), (3,)])
def test_censored_logprob_batch_shapes(batch_shape):
    """Test that log_prob works correctly with batch shapes."""
    if batch_shape:
        loc, scale, censored = (
            jnp.zeros(batch_shape),
            jnp.ones(batch_shape),
            jnp.ones(batch_shape),
        )
    else:
        loc, scale, censored = 0.0, 1.0, 1.0

    censored_dist = dist.RightCensoredDistribution(dist.Normal(loc, scale), censored)

    # A scalar value broadcasts to the batch shape.
    assert censored_dist.log_prob(0.0).shape == batch_shape

    probe = [-2.0, 0.0, 2.0]
    if batch_shape:
        # Batched: evaluate one scalar at a time to avoid broadcasting issues.
        for value in probe:
            assert censored_dist.log_prob(value).shape == batch_shape
    else:
        values = jnp.array(probe)
        assert censored_dist.log_prob(values).shape == values.shape


@pytest.mark.parametrize("batch_shape", [(), (3,)])
@param_left_censored
@param_right_censored
def test_interval_censored_logprob_batch_shapes(
    batch_shape, left_censored, right_censored
):
    """log_prob of an interval-censored Normal has the expected batch shape.

    Values are ``(lower, upper)`` pairs stacked on the last axis, which
    ``log_prob`` consumes. Every combination of the parametrized censoring
    flags is exercised, both for a scalar and for a batched base distribution.
    """
    if batch_shape == ():
        loc = 0.0
        scale = 1.0
        # The parametrized scalar flags are used unchanged; overriding them
        # here would collapse every parametrization onto the doubly-censored case.
    else:
        loc = jnp.zeros(batch_shape)
        scale = jnp.ones(batch_shape)
        left_censored = jnp.broadcast_to(left_censored, batch_shape)
        right_censored = jnp.broadcast_to(right_censored, batch_shape)

    base_dist = dist.Normal(loc, scale)
    censored_dist = dist.IntervalCensoredDistribution(
        base_dist, left_censored, right_censored
    )

    # Test with single value
    value_lower = 0.0
    value_upper = 1.0
    value = jnp.stack([value_lower, value_upper], axis=-1)
    logp_value = censored_dist.log_prob(value)
    assert logp_value.shape == batch_shape

    # Test with multiple values - these should broadcast properly
    if batch_shape == ():
        values_lower = jnp.array([-2.0, 0.0, 2.0])
        values_upper = values_lower + 1.0
        values = jnp.stack([values_lower, values_upper], axis=-1)
        logp_values = censored_dist.log_prob(values)
        # input always has (lower, upper) bounds, log_prob does not
        expected_shape = values.shape[:-1]
        assert logp_values.shape == expected_shape
    else:
        # For batched case, test with single values to avoid broadcasting issues
        for value_lower in [-2.0, 0.0, 2.0]:
            value_upper = value_lower + 1.0
            value = jnp.stack([value_lower, value_upper], axis=-1)
            logp_value = censored_dist.log_prob(value)
            assert logp_value.shape == batch_shape


def test_censored_broadcasting():
    """Censored distributions broadcast base and censoring batch shapes together.

    Left-, right- and interval-censored wrappers are all checked, for both
    orders in which the broadcast dimensions can come from.
    """
    cases = [
        # loc: (3,), censored: (2, 1) -> batch_shape: (2, 3)
        (dist.Normal(jnp.array([0.0, 1.0, 2.0]), 1.0), jnp.array([[0], [1]])),
        # loc: (2, 1), censored: (3,) -> batch_shape: (2, 3)
        (dist.Normal(jnp.array([[0], [1]]), 1.0), jnp.array([0.0, 1.0, 2.0])),
    ]
    for base_dist, censored in cases:
        censored_dists = (
            dist.LeftCensoredDistribution(base_dist, censored),
            dist.RightCensoredDistribution(base_dist, censored),
            dist.IntervalCensoredDistribution(base_dist, censored, censored),
        )
        for dist_obj in censored_dists:
            assert dist_obj.batch_shape == (2, 3), type(dist_obj).__name__


@param_left_censored
@param_right_censored
def test_censored_gradient(left_censored, right_censored):
    """Gradients of censored log-densities w.r.t. ``loc`` and ``scale`` are finite.

    Covers left-, right- and interval-censored wrappers around a Normal base.
    """
    value = jnp.array([-1.0, 0.0, 1.0])
    # Interval censoring takes observations as (lower, upper) pairs.
    value_interval = jnp.stack([value, value + 1], axis=-1)

    def assert_finite_grads(make_censored, obs, label):
        # Differentiate the summed log-density w.r.t. one Normal parameter at a
        # time, holding the other fixed (loc at 0.0, scale at 1.0).
        def loss_loc(loc):
            return make_censored(dist.Normal(loc, 1.0)).log_prob(obs).sum()

        def loss_scale(scale):
            return make_censored(dist.Normal(0.0, scale)).log_prob(obs).sum()

        # Not named ``grad`` to avoid shadowing the module-level
        # ``from jax import grad`` import.
        d_loc = jax.grad(loss_loc)(0.0)
        d_scale = jax.grad(loss_scale)(1.0)
        assert jnp.isfinite(d_loc), f"{label}: gradient w.r.t. loc is not finite"
        assert jnp.isfinite(d_scale), f"{label}: gradient w.r.t. scale is not finite"

    assert_finite_grads(
        lambda base: dist.LeftCensoredDistribution(base, left_censored),
        value,
        "left",
    )
    assert_finite_grads(
        lambda base: dist.RightCensoredDistribution(base, right_censored),
        value,
        "right",
    )
    assert_finite_grads(
        lambda base: dist.IntervalCensoredDistribution(
            base, left_censored, right_censored
        ),
        value_interval,
        "interval",
    )


@param_left_censored
@param_right_censored
def test_censored_sample_validity(left_censored, right_censored):
    """Censoring does not change sampling: samples equal the base distribution's.

    The same PRNG key is used for the base and each censored distribution, so
    the draws must coincide whether the censoring indicator is 0 or 1.
    """
    rng_key = random.PRNGKey(0)
    sample_shape = (1000,)

    base_dist = dist.Normal(0.0, 1.0)
    base_samples = base_dist.sample(rng_key, sample_shape)

    censored_dists = {
        "right": dist.RightCensoredDistribution(base_dist, right_censored),
        "left": dist.LeftCensoredDistribution(base_dist, left_censored),
        "interval": dist.IntervalCensoredDistribution(
            base_dist, left_censored, right_censored
        ),
    }
    for name, censored_dist in censored_dists.items():
        samples = censored_dist.sample(rng_key, sample_shape)
        assert samples.shape == base_samples.shape, name
        # ``jnp.all`` reduces in one op; the builtin ``all()`` would iterate the
        # array element by element and raise on a 0-d array.
        assert jnp.all(jnp.abs(base_samples - samples) < 1e-6), name


def test_censored_sample_shape():
    """Sample shapes of censored distributions include the broadcast batch shape.

    Censoring indicators of shape (1, 2) broadcast against a base batch of
    shape (3, 1), giving a batch shape of (3, 2) for every censoring flavour.
    """
    key = random.PRNGKey(0)
    indicators = jnp.array([0.0, 1.0]).reshape((1, 2))
    base = dist.Normal(jnp.zeros((3, 1)))
    draw_shape = (10, 4)
    target_shape = (*draw_shape, 3, 2)

    flavours = (
        dist.LeftCensoredDistribution(base, indicators),
        dist.RightCensoredDistribution(base, indicators),
        dist.IntervalCensoredDistribution(base, indicators, indicators),
    )
    for censored_dist in flavours:
        drawn = censored_dist.sample(key, draw_shape)
        assert drawn.shape == target_shape


@pytest.mark.parametrize(
    "left_censored, right_censored, lower, upper, censoring_type",
    [
        # left censored examples
        (1.0, 0.0, 0.001, 1.0, "left"),
        (1.0, 0.0, 0.001, 0.001, "left"),
        # right censored examples
        (0.0, 1.0, 0.001, 1.0, "right"),
        (0.0, 1.0, 0.001, 0.001, "right"),
        # interval censored example
        (0.0, 0.0, 0.001, 1.0, "interval"),
        # doubly censored example
        (1.0, 1.0, 0.001, 1.0, "double"),
        # exact example
        (0.0, 0.0, 0.001, 0.001, "exact"),
    ],
)
def test_interval_censored_masks(
    left_censored, right_censored, lower, upper, censoring_type
):
    """Exactly one censoring mask is active, and it is the expected one."""
    base_dist = dist.HalfNormal()
    censored_dist = dist.IntervalCensoredDistribution(
        base_dist,
        left_censored,
        right_censored,
        validate_args=True,
    )
    value = jnp.array([[lower, upper]])
    masks = censored_dist._get_censoring_masks(value)
    m_left, m_right, m_interval, m_doubly, m_exact = masks

    # Count active masks with integer arithmetic. If the masks are boolean
    # arrays, ``+`` in JAX is a logical OR, so adding them directly would only
    # check that *at least* one mask is set.
    n_active = sum(jnp.asarray(m, dtype=jnp.int32) for m in masks)
    assert jnp.all(n_active == 1)

    expected_mask = {
        "left": m_left,
        "right": m_right,
        "interval": m_interval,
        "double": m_doubly,
        "exact": m_exact,
    }[censoring_type]  # an unknown type fails loudly instead of passing silently
    assert jnp.all(expected_mask)


@pytest.mark.parametrize(
    "left_censored, right_censored, lower, upper, should_raise",
    [
        # left censored examples
        (1.0, 0.0, 0.001, 1.0, False),
        (1.0, 0.0, 0.001, -1.0, True),
        (1.0, 0.0, -0.001, 1.0, False),
        (1.0, 0.0, -jnp.inf, 1.0, False),
        (1.0, 0.0, jnp.nan, 1.0, False),
        # right censored examples
        (0.0, 1.0, 0.001, 1.0, False),
        (0.0, 1.0, 0.001, -1.0, False),
        (0.0, 1.0, -1.0, 0.001, True),
        (0.0, 1.0, 1.0, -jnp.inf, False),
        (0.0, 1.0, 1.0, jnp.nan, False),
        # interval, doubly, exact examples: both bounds valid
        # interval censored examples
        (0.0, 0.0, 0.001, 1.0, False),
        (0.0, 0.0, -0.001, 1.0, True),
        (0.0, 0.0, 0.001, -1.0, True),
        # doubly censored examples
        (1.0, 1.0, 0.001, 1.0, False),
        (1.0, 1.0, -0.001, 1.0, True),
        (1.0, 1.0, 0.001, -1.0, True),
        # exact examples
        (0.0, 0.0, 0.001, 0.001, False),
        (0.0, 0.0, -0.001, -0.001, True),
        # interval and doubly censored, upper should be >= lower
        (0.0, 0.0, 0.001, 0.002, False),
        (1.0, 1.0, 0.001, 0.002, False),
        (0.0, 0.0, 0.002, 0.001, True),
        (1.0, 1.0, 0.002, 0.001, True),
    ],
)
def test_interval_censored_validate_sample(
    left_censored, right_censored, lower, upper, should_raise
):
    """Invalid (lower, upper) bounds trigger a validation warning in ``log_prob``.

    Only the bounds relevant to the censoring configuration are checked against
    the HalfNormal support; e.g. a left-censored observation may have any lower
    bound, including -inf or NaN.
    """
    base_dist = dist.HalfNormal()
    censored_dist = dist.IntervalCensoredDistribution(
        base_dist,
        left_censored,
        right_censored,
        validate_args=True,
    )
    value = jnp.array([[lower, upper]])
    if should_raise:
        # ``pytest.warns`` captures the warning regardless of the configured
        # warning filters; ``pytest.raises(UserWarning)`` only works when the
        # test session turns warnings into errors.
        with pytest.warns(UserWarning):
            censored_dist.log_prob(value)
    else:
        censored_dist.log_prob(value)  # Should not raise


@pytest.mark.parametrize(
    argnames="concentration1,concentration0,value",
    argvalues=[
        (1.0, 8.0, 0.0),
        (8.0, 1.0, 1.0),
    ],
    ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"],
)
def test_beta_logprob_edge_cases(concentration1, concentration0, value):
    """Test Beta distribution with concentration=1 gives finite log probability at boundary."""
    lp = dist.Beta(concentration1, concentration0).log_prob(value)

    is_nan = jnp.isnan(lp)
    assert not is_nan, (
        f"Beta({concentration1},{concentration0}).log_prob({value}) should not be NaN"
    )

    is_finite = jnp.isfinite(lp)
    assert is_finite, (
        f"Beta({concentration1},{concentration0}).log_prob({value}) should be finite"
    )


def test_beta_logprob_edge_case_consistency_small_values():
    """Test that edge case values are consistent with small deviation values.

    For Beta(1, 8) at x=0 and Beta(8, 1) at x=1 the boundary log-density is
    finite, and it must agree with the log-density just inside the support.
    """
    beta_dist = dist.Beta(1.0, 8.0)
    beta_dist2 = dist.Beta(8.0, 1.0)

    # At boundary
    log_prob_at_zero = beta_dist.log_prob(0.0)
    log_prob_at_one = beta_dist2.log_prob(1.0)

    # Very close to boundary. ``1.0 - 1e-10`` rounds to exactly 1.0 in float32
    # (JAX's default dtype), which would compare the boundary with itself, so
    # use the largest representable value strictly below 1 instead.
    small_value = 1e-10
    close_to_one = jnp.nextafter(jnp.array(1.0), jnp.array(0.0))
    assert close_to_one < 1.0
    log_prob_small = beta_dist.log_prob(small_value)
    log_prob_close_to_one = beta_dist2.log_prob(close_to_one)

    # Edge case values should be close to small deviation values
    assert jnp.abs(log_prob_at_zero - log_prob_small) < 1e-5
    assert jnp.abs(log_prob_at_one - log_prob_close_to_one) < 1e-5


def test_beta_logprob_edge_case_non_boundary_values():
    """Unit-concentration Beta distributions stay finite in the interior of (0, 1)."""
    for alpha, beta in ((1.0, 8.0), (8.0, 1.0)):
        interior_lp = dist.Beta(alpha, beta).log_prob(0.5)
        assert jnp.isfinite(interior_lp)


def test_beta_logprob_boundary_non_edge_cases():
    """Concentrations above 1 still yield a log-density of -inf at the boundary."""
    cases = (
        (2.0, 8.0, 0.0),  # density vanishes at x = 0
        (8.0, 2.0, 1.0),  # density vanishes at x = 1
    )
    for alpha, beta, boundary in cases:
        boundary_lp = dist.Beta(alpha, beta).log_prob(boundary)
        assert jnp.isneginf(boundary_lp)


@pytest.mark.parametrize(
    argnames="concentration1,concentration0,value,grad_param,grad_value",
    argvalues=[
        (1.0, 8.0, 0.0, "value", 0.0),
        (8.0, 1.0, 1.0, "value", 1.0),
        (1.0, 8.0, 0.0, "concentration1", 1.0),
        (1.0, 8.0, 0.0, "concentration0", 8.0),
        (8.0, 1.0, 1.0, "concentration1", 8.0),
        (8.0, 1.0, 1.0, "concentration0", 1.0),
    ],
    ids=[
        "Beta(1,8) at x=0",
        "Beta(8,1) at x=1",
        "Beta(1,8) at concentration1=1",
        "Beta(1,8) at concentration0=8",
        "Beta(8,1) at concentration1=8",
        "Beta(8,1) at concentration0=1",
    ],
)
def test_beta_gradient_edge_cases_single_param(
    concentration1, concentration0, value, grad_param, grad_value
):
    """Gradients w.r.t. one parameter at a time are finite at the boundary edge cases.

    For ``grad_param == "value"`` the derivative is taken at ``value``; for a
    concentration it is taken at ``grad_value`` with the other arguments fixed.
    """

    def beta_log_prob(c1, c0, x):
        return dist.Beta(c1, c0).log_prob(x)

    # Argument slot of ``beta_log_prob`` to differentiate; anything that is not
    # "value" or "concentration1" is treated as concentration0.
    slot = {"value": 2, "concentration1": 0}.get(grad_param, 1)
    arguments = [concentration1, concentration0, value]
    if slot != 2:
        arguments[slot] = grad_value

    derivative = jax.grad(beta_log_prob, argnums=slot)(*arguments)

    assert jnp.isfinite(derivative), (
        f"Gradient w.r.t. {grad_param} for Beta({concentration1},{concentration0}) "
        f"at x={value} should be finite"
    )


@pytest.mark.parametrize(
    argnames="concentration1,concentration0,value",
    argvalues=[
        (1.0, 8.0, 0.0),
        (8.0, 1.0, 1.0),
    ],
    ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"],
)
def test_beta_gradient_edge_cases_all_params(concentration1, concentration0, value):
    """The joint gradient over (concentration1, concentration0, x) is finite at the edge."""

    def joint_log_prob(theta):
        # theta packs [concentration1, concentration0, x] into one vector.
        return dist.Beta(theta[0], theta[1]).log_prob(theta[2])

    theta = jnp.array([concentration1, concentration0, value])
    joint_grad = jax.grad(joint_log_prob)(theta)
    assert jnp.isfinite(joint_grad).all(), (
        f"All gradients for Beta({concentration1},{concentration0}) at x={value} "
        "should be finite"
    )


def test_uniform_log_prob_outside_support():
    """``Uniform.log_prob`` warns on an out-of-support value when validation is on."""
    from numpyro.distributions.distribution import validation_enabled

    # Scope validation to this test: calling ``enable_validation()`` directly
    # would leave it switched on for every test that runs afterwards. Both the
    # construction and the ``log_prob`` call happen inside the context.
    with validation_enabled():
        d = dist.Uniform(0, 1)
        with pytest.warns(
            UserWarning,
            match="Out-of-support values provided to log prob method. The value argument should be within the support.",
        ):
            d.log_prob(-0.5)
